1use std::sync::{
11 Arc,
12 Mutex,
13};
14
15use qubit_function::{
16 Callable,
17 Runnable,
18};
19use rayon::ThreadPool as RayonThreadPool;
20
21use qubit_executor::{
22 TaskHandle,
23 task::spi::{
24 TaskEndpointPair,
25 TaskRunner,
26 TaskSlot,
27 },
28};
29
30use qubit_executor::service::{
31 ExecutorService,
32 ExecutorServiceLifecycle,
33 StopReport,
34 SubmissionError,
35};
36
37use crate::{
38 pending_cancel::PendingCancel,
39 rayon_executor_service_build_error::RayonExecutorServiceBuildError,
40 rayon_executor_service_builder::RayonExecutorServiceBuilder,
41 rayon_executor_service_state::RayonExecutorServiceState,
42 rayon_task_handle::RayonTaskHandle,
43};
44
/// Executor service that runs submitted tasks on a Rayon thread pool.
///
/// Cloning is cheap: both fields are `Arc`s, so every clone shares the same
/// pool and the same service state.
#[derive(Clone)]
pub struct RayonExecutorService {
    /// Rayon pool on which accepted tasks are queued (via `spawn_fifo`).
    pub(crate) pool: Arc<RayonThreadPool>,
    /// Shared service state: submission lock, lifecycle, task-id allocation,
    /// pending-task registry and acceptance/completion accounting.
    // NOTE: field order is also drop order (pool first, then state).
    pub(crate) state: Arc<RayonExecutorServiceState>,
}
57
58impl RayonExecutorService {
59 #[inline]
71 pub fn new() -> Result<Self, RayonExecutorServiceBuildError> {
72 Self::builder().build()
73 }
74
75 #[inline]
81 pub fn builder() -> RayonExecutorServiceBuilder {
82 RayonExecutorServiceBuilder::default()
83 }
84
85 fn submit_callable_with<C, R, E, H, F>(
103 &self,
104 task: C,
105 split: F,
106 ) -> Result<(H, usize, PendingCancel), SubmissionError>
107 where
108 C: Callable<R, E> + Send + 'static,
109 R: Send + 'static,
110 E: Send + 'static,
111 F: FnOnce(TaskEndpointPair<R, E>) -> (H, TaskSlot<R, E>),
112 {
113 let submission_guard = self.state.lock_submission();
114 if self.state.is_not_running() {
115 return Err(SubmissionError::Shutdown);
116 }
117 let task_id = self.state.next_task_id();
118 self.state.on_task_accepted();
119 let (handle, completion) = split(TaskEndpointPair::new());
120 completion.accept();
121 let completion = Arc::new(Mutex::new(Some(completion)));
122 let completion_for_cancel = Arc::clone(&completion);
123 let cancel: PendingCancel = Arc::new(move || {
124 let completion = completion_for_cancel
125 .lock()
126 .unwrap_or_else(std::sync::PoisonError::into_inner)
127 .take();
128 completion.is_some_and(|completion| completion.cancel_unstarted())
129 });
130 self.state
131 .register_pending_task(task_id, Arc::clone(&cancel));
132 drop(submission_guard);
133
134 let completion_for_run = completion;
135 let state_for_run = Arc::clone(&self.state);
136 self.pool.spawn_fifo(move || {
137 if !state_for_run.start_pending_task(task_id, || true) {
138 return;
139 }
140 let completion = completion_for_run
141 .lock()
142 .unwrap_or_else(std::sync::PoisonError::into_inner)
143 .take();
144 if let Some(completion) = completion {
145 TaskRunner::new(task).run(completion);
146 }
147 state_for_run.on_task_completed();
148 });
149 Ok((handle, task_id, cancel))
150 }
151}
152
153impl ExecutorService for RayonExecutorService {
154 type ResultHandle<R, E>
155 = TaskHandle<R, E>
156 where
157 R: Send + 'static,
158 E: Send + 'static;
159
160 type TrackedHandle<R, E>
161 = RayonTaskHandle<R, E>
162 where
163 R: Send + 'static,
164 E: Send + 'static;
165
166 fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
168 where
169 T: Runnable<E> + Send + 'static,
170 E: Send + 'static,
171 {
172 let submission_guard = self.state.lock_submission();
173 if self.state.is_not_running() {
174 return Err(SubmissionError::Shutdown);
175 }
176 let task_id = self.state.next_task_id();
177 self.state.on_task_accepted();
178 let cancel: PendingCancel = Arc::new(|| true);
179 self.state
180 .register_pending_task(task_id, Arc::clone(&cancel));
181 drop(submission_guard);
182
183 let state_for_run = Arc::clone(&self.state);
184 self.pool.spawn_fifo(move || {
185 if !state_for_run.start_pending_task(task_id, || true) {
186 return;
187 }
188 let mut task = task;
189 let _ignored = TaskRunner::new(move || task.run()).call::<(), E>();
190 state_for_run.on_task_completed();
191 });
192 Ok(())
193 }
194
195 fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
210 where
211 C: Callable<R, E> + Send + 'static,
212 R: Send + 'static,
213 E: Send + 'static,
214 {
215 let (handle, _, _) = self.submit_callable_with(task, TaskEndpointPair::into_parts)?;
216 Ok(handle)
217 }
218
219 fn submit_tracked_callable<C, R, E>(
221 &self,
222 task: C,
223 ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
224 where
225 C: Callable<R, E> + Send + 'static,
226 R: Send + 'static,
227 E: Send + 'static,
228 {
229 let (handle, task_id, cancel) =
230 self.submit_callable_with(task, TaskEndpointPair::into_tracked_parts)?;
231 Ok(RayonTaskHandle::new(
232 handle,
233 task_id,
234 Arc::clone(&self.state),
235 cancel,
236 ))
237 }
238
239 fn shutdown(&self) {
243 let _guard = self.state.lock_submission();
244 self.state.shutdown();
245 self.state.notify_if_terminated();
246 }
247
248 fn stop(&self) -> StopReport {
260 let _guard = self.state.lock_submission();
261 self.state.stop();
262 self.state.cancel_pending_tasks_for_stop()
263 }
264
265 fn lifecycle(&self) -> ExecutorServiceLifecycle {
267 self.state.lifecycle()
268 }
269
270 fn is_not_running(&self) -> bool {
272 self.state.is_not_running()
273 }
274
275 fn is_terminated(&self) -> bool {
277 self.lifecycle() == ExecutorServiceLifecycle::Terminated
278 }
279
280 fn wait_termination(&self) {
282 self.state.wait_for_termination();
283 }
284}