Skip to main content

qubit_rayon_executor/
rayon_executor_service.rs

1/*******************************************************************************
2 *
3 *    Copyright (c) 2025 - 2026 Haixing Hu.
4 *
5 *    SPDX-License-Identifier: Apache-2.0
6 *
7 *    Licensed under the Apache License, Version 2.0.
8 *
9 ******************************************************************************/
10use std::sync::{
11    Arc,
12    Mutex,
13};
14
15use qubit_function::{
16    Callable,
17    Runnable,
18};
19use rayon::ThreadPool as RayonThreadPool;
20
21use qubit_executor::{
22    TaskHandle,
23    task::spi::{
24        TaskEndpointPair,
25        TaskRunner,
26        TaskSlot,
27    },
28};
29
30use qubit_executor::service::{
31    ExecutorService,
32    ExecutorServiceLifecycle,
33    StopReport,
34    SubmissionError,
35};
36
37use crate::{
38    pending_cancel::PendingCancel,
39    rayon_executor_service_build_error::RayonExecutorServiceBuildError,
40    rayon_executor_service_builder::RayonExecutorServiceBuilder,
41    rayon_executor_service_state::RayonExecutorServiceState,
42    rayon_task_handle::RayonTaskHandle,
43};
44
45/// Rayon-backed executor service for CPU-bound synchronous tasks.
46///
47/// Accepted tasks are executed on a dedicated Rayon thread pool. The service
48/// preserves the crate's `ExecutorService` lifecycle semantics and task-handle
49/// APIs while delegating scheduling to Rayon.
50#[derive(Clone)]
51pub struct RayonExecutorService {
52    /// Rayon thread pool used to execute accepted tasks.
53    pub(crate) pool: Arc<RayonThreadPool>,
54    /// Shared lifecycle and cancellation state.
55    pub(crate) state: Arc<RayonExecutorServiceState>,
56}
57
58impl RayonExecutorService {
59    /// Creates a Rayon executor service with default builder settings.
60    ///
61    /// # Returns
62    ///
63    /// `Ok(RayonExecutorService)` if the default Rayon thread pool can be
64    /// built.
65    ///
66    /// # Errors
67    ///
68    /// Returns [`RayonExecutorServiceBuildError`] if the default builder
69    /// configuration is rejected.
70    #[inline]
71    pub fn new() -> Result<Self, RayonExecutorServiceBuildError> {
72        Self::builder().build()
73    }
74
75    /// Creates a builder for configuring a Rayon executor service.
76    ///
77    /// # Returns
78    ///
79    /// A builder configured with CPU-parallelism defaults.
80    #[inline]
81    pub fn builder() -> RayonExecutorServiceBuilder {
82        RayonExecutorServiceBuilder::default()
83    }
84
85    /// Accepts a callable, schedules it on the Rayon pool, and returns its handle data.
86    ///
87    /// # Parameters
88    ///
89    /// * `task` - Callable to execute on a Rayon worker.
90    /// * `split` - Function that splits a task endpoint pair into the caller
91    ///   handle and runner slot required by the chosen handle type.
92    ///
93    /// # Returns
94    ///
95    /// The caller-facing handle, stable task identifier, and pending
96    /// cancellation hook.
97    ///
98    /// # Errors
99    ///
100    /// Returns [`SubmissionError::Shutdown`] if shutdown or stop has already
101    /// been requested before the task is accepted.
102    fn submit_callable_with<C, R, E, H, F>(
103        &self,
104        task: C,
105        split: F,
106    ) -> Result<(H, usize, PendingCancel), SubmissionError>
107    where
108        C: Callable<R, E> + Send + 'static,
109        R: Send + 'static,
110        E: Send + 'static,
111        F: FnOnce(TaskEndpointPair<R, E>) -> (H, TaskSlot<R, E>),
112    {
113        let submission_guard = self.state.lock_submission();
114        if self.state.is_not_running() {
115            return Err(SubmissionError::Shutdown);
116        }
117        let task_id = self.state.next_task_id();
118        self.state.on_task_accepted();
119        let (handle, completion) = split(TaskEndpointPair::new());
120        completion.accept();
121        let completion = Arc::new(Mutex::new(Some(completion)));
122        let completion_for_cancel = Arc::clone(&completion);
123        let cancel: PendingCancel = Arc::new(move || {
124            let completion = completion_for_cancel
125                .lock()
126                .unwrap_or_else(std::sync::PoisonError::into_inner)
127                .take();
128            completion.is_some_and(|completion| completion.cancel_unstarted())
129        });
130        self.state
131            .register_pending_task(task_id, Arc::clone(&cancel));
132        drop(submission_guard);
133
134        let completion_for_run = completion;
135        let state_for_run = Arc::clone(&self.state);
136        self.pool.spawn_fifo(move || {
137            let mut running_completion = None;
138            if !state_for_run.start_pending_task(task_id, || {
139                let mut completion = completion_for_run
140                    .lock()
141                    .unwrap_or_else(std::sync::PoisonError::into_inner);
142                let Some(task_completion) = completion.take() else {
143                    return false;
144                };
145                match task_completion.try_start() {
146                    Ok(running) => {
147                        running_completion = Some(running);
148                        true
149                    }
150                    Err(task_completion) => {
151                        *completion = Some(task_completion);
152                        false
153                    }
154                }
155            }) {
156                return;
157            }
158            let running_completion =
159                running_completion.expect("claimed pending task should own a running slot");
160            TaskRunner::new(task).run_started(running_completion);
161            state_for_run.on_task_completed();
162        });
163        Ok((handle, task_id, cancel))
164    }
165}
166
167impl ExecutorService for RayonExecutorService {
168    type ResultHandle<R, E>
169        = TaskHandle<R, E>
170    where
171        R: Send + 'static,
172        E: Send + 'static;
173
174    type TrackedHandle<R, E>
175        = RayonTaskHandle<R, E>
176    where
177        R: Send + 'static,
178        E: Send + 'static;
179
180    /// Accepts a runnable and schedules it on the Rayon thread pool.
181    fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
182    where
183        T: Runnable<E> + Send + 'static,
184        E: Send + 'static,
185    {
186        let submission_guard = self.state.lock_submission();
187        if self.state.is_not_running() {
188            return Err(SubmissionError::Shutdown);
189        }
190        let task_id = self.state.next_task_id();
191        self.state.on_task_accepted();
192        let cancel: PendingCancel = Arc::new(|| true);
193        self.state
194            .register_pending_task(task_id, Arc::clone(&cancel));
195        drop(submission_guard);
196
197        let state_for_run = Arc::clone(&self.state);
198        self.pool.spawn_fifo(move || {
199            if !state_for_run.start_pending_task(task_id, || true) {
200                return;
201            }
202            let mut task = task;
203            let _ignored = TaskRunner::new(move || task.run()).call::<(), E>();
204            state_for_run.on_task_completed();
205        });
206        Ok(())
207    }
208
209    /// Accepts a callable and schedules it on the Rayon thread pool.
210    ///
211    /// # Parameters
212    ///
213    /// * `task` - Callable to execute on a Rayon worker.
214    ///
215    /// # Returns
216    ///
217    /// A [`TaskHandle`] for the accepted task.
218    ///
219    /// # Errors
220    ///
221    /// Returns [`SubmissionError::Shutdown`] if shutdown has already been
222    /// requested before the task is accepted.
223    fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
224    where
225        C: Callable<R, E> + Send + 'static,
226        R: Send + 'static,
227        E: Send + 'static,
228    {
229        let (handle, _, _) = self.submit_callable_with(task, TaskEndpointPair::into_parts)?;
230        Ok(handle)
231    }
232
233    /// Accepts a callable and schedules it with a tracked handle.
234    fn submit_tracked_callable<C, R, E>(
235        &self,
236        task: C,
237    ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
238    where
239        C: Callable<R, E> + Send + 'static,
240        R: Send + 'static,
241        E: Send + 'static,
242    {
243        let (handle, task_id, cancel) =
244            self.submit_callable_with(task, TaskEndpointPair::into_tracked_parts)?;
245        Ok(RayonTaskHandle::new(
246            handle,
247            task_id,
248            Arc::clone(&self.state),
249            cancel,
250        ))
251    }
252
253    /// Stops accepting new tasks.
254    ///
255    /// Already accepted Rayon tasks are allowed to finish normally.
256    fn shutdown(&self) {
257        let _guard = self.state.lock_submission();
258        self.state.shutdown();
259        self.state.notify_if_terminated();
260    }
261
262    /// Stops accepting new tasks and cancels tasks that have not started yet.
263    ///
264    /// Running Rayon tasks cannot be preempted. Cancellation therefore applies
265    /// only to tasks that are still pending when the cancellation hook wins the
266    /// race against task start.
267    ///
268    /// # Returns
269    ///
270    /// A count-based report describing the pending and running work observed at
271    /// the time of the stop request, plus the number of pending tasks for
272    /// which cancellation succeeded.
273    fn stop(&self) -> StopReport {
274        let _guard = self.state.lock_submission();
275        self.state.stop();
276        self.state.cancel_pending_tasks_for_stop()
277    }
278
279    /// Returns the current lifecycle state.
280    fn lifecycle(&self) -> ExecutorServiceLifecycle {
281        self.state.lifecycle()
282    }
283
284    /// Returns whether shutdown has been requested.
285    fn is_not_running(&self) -> bool {
286        self.state.is_not_running()
287    }
288
289    /// Returns whether shutdown was requested and no accepted tasks remain.
290    fn is_terminated(&self) -> bool {
291        self.lifecycle() == ExecutorServiceLifecycle::Terminated
292    }
293
294    /// Blocks until the service has terminated.
295    fn wait_termination(&self) {
296        self.state.wait_for_termination();
297    }
298}