qubit_rayon_executor/
rayon_executor_service.rs1use std::{
11 sync::{
12 Arc,
13 Mutex,
14 },
15 thread,
16 time::Duration,
17};
18
19use qubit_function::{
20 Callable,
21 Runnable,
22};
23use rayon::ThreadPool as RayonThreadPool;
24
25use qubit_executor::{
26 TaskHandle,
27 task::spi::{
28 TaskEndpointPair,
29 TaskRunner,
30 },
31};
32
33use qubit_executor::service::{
34 ExecutorService,
35 ExecutorServiceLifecycle,
36 StopReport,
37 SubmissionError,
38};
39
40use crate::{
41 pending_cancel::PendingCancel,
42 rayon_executor_service_build_error::RayonExecutorServiceBuildError,
43 rayon_executor_service_builder::RayonExecutorServiceBuilder,
44 rayon_executor_service_state::RayonExecutorServiceState,
45 rayon_task_handle::RayonTaskHandle,
46};
47
/// [`ExecutorService`] implementation that runs submitted tasks on a Rayon
/// thread pool.
///
/// Cloning is cheap: every clone shares the same pool and the same service
/// state (both are behind `Arc`), so lifecycle operations such as `shutdown`
/// or `stop` performed through one clone are observed by all of them.
#[derive(Clone)]
pub struct RayonExecutorService {
    /// Rayon pool on which accepted tasks are queued (via `spawn_fifo`).
    pub(crate) pool: Arc<RayonThreadPool>,
    /// Shared bookkeeping: lifecycle, task ids/counters, the submission lock
    /// and the registry of pending (not yet started) tasks.
    pub(crate) state: Arc<RayonExecutorServiceState>,
}
60
61impl RayonExecutorService {
62 #[inline]
74 pub fn new() -> Result<Self, RayonExecutorServiceBuildError> {
75 Self::builder().build()
76 }
77
78 #[inline]
84 pub fn builder() -> RayonExecutorServiceBuilder {
85 RayonExecutorServiceBuilder::default()
86 }
87}
88
89impl ExecutorService for RayonExecutorService {
90 type ResultHandle<R, E>
91 = TaskHandle<R, E>
92 where
93 R: Send + 'static,
94 E: Send + 'static;
95
96 type TrackedHandle<R, E>
97 = RayonTaskHandle<R, E>
98 where
99 R: Send + 'static,
100 E: Send + 'static;
101
102 fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
104 where
105 T: Runnable<E> + Send + 'static,
106 E: Send + 'static,
107 {
108 let submission_guard = self.state.lock_submission();
109 if self.state.is_not_running() {
110 return Err(SubmissionError::Shutdown);
111 }
112 let task_id = self.state.next_task_id();
113 self.state.on_task_accepted();
114 let cancel: PendingCancel = Arc::new(|| true);
115 self.state
116 .register_pending_task(task_id, Arc::clone(&cancel));
117 drop(submission_guard);
118
119 let state_for_run = Arc::clone(&self.state);
120 self.pool.spawn_fifo(move || {
121 if !state_for_run.start_pending_task(task_id, || true) {
122 return;
123 }
124 let mut task = task;
125 let _ignored = TaskRunner::new(move || task.run()).call::<(), E>();
126 state_for_run.on_task_completed();
127 });
128 Ok(())
129 }
130
131 fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
146 where
147 C: Callable<R, E> + Send + 'static,
148 R: Send + 'static,
149 E: Send + 'static,
150 {
151 let submission_guard = self.state.lock_submission();
152 if self.state.is_not_running() {
153 return Err(SubmissionError::Shutdown);
154 }
155 let task_id = self.state.next_task_id();
156 self.state.on_task_accepted();
157 let (handle, completion) = TaskEndpointPair::new().into_parts();
158 completion.accept();
159 let completion = Arc::new(Mutex::new(Some(completion)));
160 let completion_for_cancel = Arc::clone(&completion);
161 let cancel: PendingCancel = Arc::new(move || {
162 let completion = completion_for_cancel
163 .lock()
164 .unwrap_or_else(std::sync::PoisonError::into_inner)
165 .take();
166 completion.is_some_and(|completion| completion.cancel_unstarted())
167 });
168 self.state
169 .register_pending_task(task_id, Arc::clone(&cancel));
170 drop(submission_guard);
171
172 let completion_for_run = completion;
173 let state_for_run = Arc::clone(&self.state);
174 self.pool.spawn_fifo(move || {
175 if !state_for_run.start_pending_task(task_id, || true) {
176 return;
177 }
178 let completion = completion_for_run
179 .lock()
180 .unwrap_or_else(std::sync::PoisonError::into_inner)
181 .take();
182 if let Some(completion) = completion {
183 TaskRunner::new(task).run(completion);
184 }
185 state_for_run.on_task_completed();
186 });
187 Ok(handle)
188 }
189
190 fn submit_tracked_callable<C, R, E>(
192 &self,
193 task: C,
194 ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
195 where
196 C: Callable<R, E> + Send + 'static,
197 R: Send + 'static,
198 E: Send + 'static,
199 {
200 let submission_guard = self.state.lock_submission();
201 if self.state.is_not_running() {
202 return Err(SubmissionError::Shutdown);
203 }
204 let task_id = self.state.next_task_id();
205 self.state.on_task_accepted();
206 let (handle, completion) = TaskEndpointPair::new().into_tracked_parts();
207 completion.accept();
208 let completion = Arc::new(Mutex::new(Some(completion)));
209 let completion_for_cancel = Arc::clone(&completion);
210 let cancel: PendingCancel = Arc::new(move || {
211 let completion = completion_for_cancel
212 .lock()
213 .unwrap_or_else(std::sync::PoisonError::into_inner)
214 .take();
215 completion.is_some_and(|completion| completion.cancel_unstarted())
216 });
217 self.state
218 .register_pending_task(task_id, Arc::clone(&cancel));
219 drop(submission_guard);
220
221 let completion_for_run = completion;
222 let state_for_run = Arc::clone(&self.state);
223 self.pool.spawn_fifo(move || {
224 if !state_for_run.start_pending_task(task_id, || true) {
225 return;
226 }
227 let completion = completion_for_run
228 .lock()
229 .unwrap_or_else(std::sync::PoisonError::into_inner)
230 .take();
231 if let Some(completion) = completion {
232 TaskRunner::new(task).run(completion);
233 }
234 state_for_run.on_task_completed();
235 });
236 Ok(RayonTaskHandle::new(
237 handle,
238 task_id,
239 Arc::clone(&self.state),
240 cancel,
241 ))
242 }
243
244 fn shutdown(&self) {
248 let _guard = self.state.lock_submission();
249 self.state.shutdown();
250 self.state.notify_if_terminated();
251 }
252
253 fn stop(&self) -> StopReport {
265 let _guard = self.state.lock_submission();
266 self.state.stop();
267 let (queued, running, pending) = self.state.drain_pending_tasks_for_shutdown();
268 drop(_guard);
269
270 let cancelled = self.state.cancel_drained_pending_tasks(pending);
271 StopReport::new(queued, running, cancelled)
272 }
273
274 fn lifecycle(&self) -> ExecutorServiceLifecycle {
276 self.state.lifecycle()
277 }
278
279 fn is_not_running(&self) -> bool {
281 self.state.is_not_running()
282 }
283
284 fn is_terminated(&self) -> bool {
286 self.lifecycle() == ExecutorServiceLifecycle::Terminated
287 }
288
289 fn wait_termination(&self) {
291 while !self.is_terminated() {
292 thread::sleep(Duration::from_millis(1));
293 }
294 }
295}