use std::sync::{
Arc,
Mutex,
};
use qubit_function::{
Callable,
Runnable,
};
use rayon::ThreadPool as RayonThreadPool;
use qubit_executor::{
TaskHandle,
task::spi::{
TaskEndpointPair,
TaskRunner,
TaskSlot,
},
};
use qubit_executor::service::{
ExecutorService,
ExecutorServiceLifecycle,
StopReport,
SubmissionError,
};
use crate::{
pending_cancel::PendingCancel,
rayon_executor_service_build_error::RayonExecutorServiceBuildError,
rayon_executor_service_builder::RayonExecutorServiceBuilder,
rayon_executor_service_state::RayonExecutorServiceState,
rayon_task_handle::RayonTaskHandle,
};
/// Executor service that runs submitted tasks on a Rayon thread pool.
///
/// Both fields are held behind [`Arc`], so the derived `Clone` produces
/// another handle to the same pool and the same bookkeeping state rather
/// than an independent service.
#[derive(Clone)]
pub struct RayonExecutorService {
/// Rayon pool onto which accepted tasks are spawned (via `spawn_fifo`).
pub(crate) pool: Arc<RayonThreadPool>,
/// Shared service state: submission lock, task-id allocation, pending-task
/// registry, lifecycle queries and termination waiting.
pub(crate) state: Arc<RayonExecutorServiceState>,
}
impl RayonExecutorService {
    /// Builds a service from a default-configured builder.
    ///
    /// # Errors
    ///
    /// Propagates the [`RayonExecutorServiceBuildError`] reported by the
    /// builder's `build` call.
    #[inline]
    pub fn new() -> Result<Self, RayonExecutorServiceBuildError> {
        RayonExecutorServiceBuilder::default().build()
    }

    /// Returns a default [`RayonExecutorServiceBuilder`] for customising the
    /// service before it is built.
    #[inline]
    pub fn builder() -> RayonExecutorServiceBuilder {
        RayonExecutorServiceBuilder::default()
    }

    /// Accepts a callable, registers it as pending and spawns it on the pool.
    ///
    /// `split` turns a fresh [`TaskEndpointPair`] into the caller-facing handle
    /// `H` and the [`TaskSlot`] used to run or cancel the task; this lets the
    /// plain and tracked submission paths share one implementation.
    ///
    /// On success returns the handle, the allocated task id and the
    /// cancellation callback that was registered for the pending task.
    ///
    /// # Errors
    ///
    /// Returns [`SubmissionError::Shutdown`] when the state reports that the
    /// service is not running at the time the submission lock is held.
    fn submit_callable_with<C, R, E, H, F>(
        &self,
        task: C,
        split: F,
    ) -> Result<(H, usize, PendingCancel), SubmissionError>
    where
        C: Callable<R, E> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
        F: FnOnce(TaskEndpointPair<R, E>) -> (H, TaskSlot<R, E>),
    {
        // Everything up to `drop(submission_lock)` happens under the
        // submission lock: the running check, id allocation, acceptance and
        // pending registration.
        let submission_lock = self.state.lock_submission();
        if self.state.is_not_running() {
            return Err(SubmissionError::Shutdown);
        }
        let task_id = self.state.next_task_id();
        self.state.on_task_accepted();

        let (handle, slot) = split(TaskEndpointPair::new());
        slot.accept();

        // The slot is shared between the cancel callback and the worker
        // closure; whichever `take`s it first owns it, the other sees `None`.
        let shared_slot = Arc::new(Mutex::new(Some(slot)));
        let slot_for_cancel = Arc::clone(&shared_slot);
        let cancel: PendingCancel = Arc::new(move || {
            // The guard is a temporary of this statement, so the mutex is
            // released before `cancel_unstarted` is invoked.
            let unstarted = slot_for_cancel
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .take();
            match unstarted {
                Some(slot) => slot.cancel_unstarted(),
                None => false,
            }
        });
        self.state
            .register_pending_task(task_id, Arc::clone(&cancel));
        drop(submission_lock);

        let state = Arc::clone(&self.state);
        self.pool.spawn_fifo(move || {
            // If the state refuses to start the task, do nothing further here.
            if state.start_pending_task(task_id, || true) {
                let claimed = shared_slot
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner)
                    .take();
                if let Some(slot) = claimed {
                    TaskRunner::new(task).run(slot);
                }
                state.on_task_completed();
            }
        });

        Ok((handle, task_id, cancel))
    }
}
impl ExecutorService for RayonExecutorService {
    /// Handle returned by [`ExecutorService::submit_callable`].
    type ResultHandle<R, E>
        = TaskHandle<R, E>
    where
        R: Send + 'static,
        E: Send + 'static;

    /// Handle returned by [`ExecutorService::submit_tracked_callable`].
    type TrackedHandle<R, E>
        = RayonTaskHandle<R, E>
    where
        R: Send + 'static,
        E: Send + 'static;

    /// Accepts a runnable and spawns it on the pool without returning a handle.
    ///
    /// The task's outcome is discarded after it runs.
    ///
    /// # Errors
    ///
    /// Returns [`SubmissionError::Shutdown`] when the state reports that the
    /// service is not running at the time the submission lock is held.
    fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
    where
        T: Runnable<E> + Send + 'static,
        E: Send + 'static,
    {
        // Running check, id allocation, acceptance and pending registration
        // all happen while the submission lock is held.
        let submission_lock = self.state.lock_submission();
        if self.state.is_not_running() {
            return Err(SubmissionError::Shutdown);
        }
        let task_id = self.state.next_task_id();
        self.state.on_task_accepted();
        // There is no result endpoint to notify, so the registered cancel
        // callback simply returns `true`.
        let cancel: PendingCancel = Arc::new(|| true);
        self.state
            .register_pending_task(task_id, Arc::clone(&cancel));
        drop(submission_lock);

        let state = Arc::clone(&self.state);
        self.pool.spawn_fifo(move || {
            // If the state refuses to start the task, do nothing further here.
            if state.start_pending_task(task_id, || true) {
                let mut runnable = task;
                // Kept in a named binding so the outcome is dropped only after
                // completion has been recorded.
                let _outcome = TaskRunner::new(move || runnable.run()).call::<(), E>();
                state.on_task_completed();
            }
        });

        Ok(())
    }

    /// Submits a callable and returns the plain result handle for it.
    ///
    /// # Errors
    ///
    /// Returns [`SubmissionError::Shutdown`] when the service is not running.
    fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
    where
        C: Callable<R, E> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
    {
        // The task id and cancel callback are not needed for untracked handles.
        self.submit_callable_with(task, TaskEndpointPair::into_parts)
            .map(|(handle, _, _)| handle)
    }

    /// Submits a callable and returns a tracked handle that carries the task
    /// id, a reference to the shared state and the cancel callback.
    ///
    /// # Errors
    ///
    /// Returns [`SubmissionError::Shutdown`] when the service is not running.
    fn submit_tracked_callable<C, R, E>(
        &self,
        task: C,
    ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
    where
        C: Callable<R, E> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
    {
        let (inner, task_id, cancel) =
            self.submit_callable_with(task, TaskEndpointPair::into_tracked_parts)?;
        let shared_state = Arc::clone(&self.state);
        Ok(RayonTaskHandle::new(inner, task_id, shared_state, cancel))
    }

    /// Requests shutdown on the shared state while holding the submission
    /// lock, then asks the state to notify waiters if it is already terminated.
    fn shutdown(&self) {
        let _submission_lock = self.state.lock_submission();
        self.state.shutdown();
        self.state.notify_if_terminated();
    }

    /// Requests stop on the shared state while holding the submission lock and
    /// returns the report produced by cancelling the pending tasks.
    fn stop(&self) -> StopReport {
        let _submission_lock = self.state.lock_submission();
        self.state.stop();
        self.state.cancel_pending_tasks_for_stop()
    }

    /// Returns the lifecycle currently reported by the shared state.
    fn lifecycle(&self) -> ExecutorServiceLifecycle {
        self.state.lifecycle()
    }

    /// Returns whether the shared state reports the service as not running.
    fn is_not_running(&self) -> bool {
        self.state.is_not_running()
    }

    /// Returns `true` when the current lifecycle is
    /// [`ExecutorServiceLifecycle::Terminated`].
    fn is_terminated(&self) -> bool {
        let current = self.lifecycle();
        current == ExecutorServiceLifecycle::Terminated
    }

    /// Delegates to the shared state's `wait_for_termination`.
    fn wait_termination(&self) {
        self.state.wait_for_termination();
    }
}