1use std::sync::{
11 Arc,
12 Mutex,
13};
14
15use qubit_function::{
16 Callable,
17 Runnable,
18};
19use rayon::ThreadPool as RayonThreadPool;
20
21use qubit_executor::{
22 TaskHandle,
23 task::spi::{
24 TaskEndpointPair,
25 TaskRunner,
26 TaskSlot,
27 },
28};
29
30use qubit_executor::service::{
31 ExecutorService,
32 ExecutorServiceLifecycle,
33 StopReport,
34 SubmissionError,
35};
36
37use crate::{
38 pending_cancel::PendingCancel,
39 rayon_executor_service_build_error::RayonExecutorServiceBuildError,
40 rayon_executor_service_builder::RayonExecutorServiceBuilder,
41 rayon_executor_service_state::RayonExecutorServiceState,
42 rayon_task_handle::RayonTaskHandle,
43};
44
45#[derive(Clone)]
51pub struct RayonExecutorService {
52 pub(crate) pool: Arc<RayonThreadPool>,
54 pub(crate) state: Arc<RayonExecutorServiceState>,
56}
57
58impl RayonExecutorService {
59 #[inline]
71 pub fn new() -> Result<Self, RayonExecutorServiceBuildError> {
72 Self::builder().build()
73 }
74
75 #[inline]
81 pub fn builder() -> RayonExecutorServiceBuilder {
82 RayonExecutorServiceBuilder::default()
83 }
84
85 fn submit_callable_with<C, R, E, H, F>(
103 &self,
104 task: C,
105 split: F,
106 ) -> Result<(H, usize, PendingCancel), SubmissionError>
107 where
108 C: Callable<R, E> + Send + 'static,
109 R: Send + 'static,
110 E: Send + 'static,
111 F: FnOnce(TaskEndpointPair<R, E>) -> (H, TaskSlot<R, E>),
112 {
113 let submission_guard = self.state.lock_submission();
114 if self.state.is_not_running() {
115 return Err(SubmissionError::Shutdown);
116 }
117 let task_id = self.state.next_task_id();
118 self.state.on_task_accepted();
119 let (handle, completion) = split(TaskEndpointPair::new());
120 completion.accept();
121 let completion = Arc::new(Mutex::new(Some(completion)));
122 let completion_for_cancel = Arc::clone(&completion);
123 let cancel: PendingCancel = Arc::new(move || {
124 let completion = completion_for_cancel
125 .lock()
126 .unwrap_or_else(std::sync::PoisonError::into_inner)
127 .take();
128 completion.is_some_and(|completion| completion.cancel_unstarted())
129 });
130 self.state
131 .register_pending_task(task_id, Arc::clone(&cancel));
132 drop(submission_guard);
133
134 let completion_for_run = completion;
135 let state_for_run = Arc::clone(&self.state);
136 self.pool.spawn_fifo(move || {
137 let mut running_completion = None;
138 if !state_for_run.start_pending_task(task_id, || {
139 let mut completion = completion_for_run
140 .lock()
141 .unwrap_or_else(std::sync::PoisonError::into_inner);
142 let Some(task_completion) = completion.take() else {
143 return false;
144 };
145 match task_completion.try_start() {
146 Ok(running) => {
147 running_completion = Some(running);
148 true
149 }
150 Err(task_completion) => {
151 *completion = Some(task_completion);
152 false
153 }
154 }
155 }) {
156 return;
157 }
158 let running_completion =
159 running_completion.expect("claimed pending task should own a running slot");
160 TaskRunner::new(task).run_started(running_completion);
161 state_for_run.on_task_completed();
162 });
163 Ok((handle, task_id, cancel))
164 }
165}
166
167impl ExecutorService for RayonExecutorService {
168 type ResultHandle<R, E>
169 = TaskHandle<R, E>
170 where
171 R: Send + 'static,
172 E: Send + 'static;
173
174 type TrackedHandle<R, E>
175 = RayonTaskHandle<R, E>
176 where
177 R: Send + 'static,
178 E: Send + 'static;
179
180 fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
182 where
183 T: Runnable<E> + Send + 'static,
184 E: Send + 'static,
185 {
186 let submission_guard = self.state.lock_submission();
187 if self.state.is_not_running() {
188 return Err(SubmissionError::Shutdown);
189 }
190 let task_id = self.state.next_task_id();
191 self.state.on_task_accepted();
192 let cancel: PendingCancel = Arc::new(|| true);
193 self.state
194 .register_pending_task(task_id, Arc::clone(&cancel));
195 drop(submission_guard);
196
197 let state_for_run = Arc::clone(&self.state);
198 self.pool.spawn_fifo(move || {
199 if !state_for_run.start_pending_task(task_id, || true) {
200 return;
201 }
202 let mut task = task;
203 let _ignored = TaskRunner::new(move || task.run()).call::<(), E>();
204 state_for_run.on_task_completed();
205 });
206 Ok(())
207 }
208
209 fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
224 where
225 C: Callable<R, E> + Send + 'static,
226 R: Send + 'static,
227 E: Send + 'static,
228 {
229 let (handle, _, _) = self.submit_callable_with(task, TaskEndpointPair::into_parts)?;
230 Ok(handle)
231 }
232
233 fn submit_tracked_callable<C, R, E>(
235 &self,
236 task: C,
237 ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
238 where
239 C: Callable<R, E> + Send + 'static,
240 R: Send + 'static,
241 E: Send + 'static,
242 {
243 let (handle, task_id, cancel) =
244 self.submit_callable_with(task, TaskEndpointPair::into_tracked_parts)?;
245 Ok(RayonTaskHandle::new(
246 handle,
247 task_id,
248 Arc::clone(&self.state),
249 cancel,
250 ))
251 }
252
253 fn shutdown(&self) {
257 let _guard = self.state.lock_submission();
258 self.state.shutdown();
259 self.state.notify_if_terminated();
260 }
261
262 fn stop(&self) -> StopReport {
274 let _guard = self.state.lock_submission();
275 self.state.stop();
276 self.state.cancel_pending_tasks_for_stop()
277 }
278
279 fn lifecycle(&self) -> ExecutorServiceLifecycle {
281 self.state.lifecycle()
282 }
283
284 fn is_not_running(&self) -> bool {
286 self.state.is_not_running()
287 }
288
289 fn is_terminated(&self) -> bool {
291 self.lifecycle() == ExecutorServiceLifecycle::Terminated
292 }
293
294 fn wait_termination(&self) {
296 self.state.wait_for_termination();
297 }
298}