Skip to main content

qubit_rayon_executor/
rayon_executor_service.rs

1/*******************************************************************************
2 *
3 *    Copyright (c) 2025 - 2026 Haixing Hu.
4 *
5 *    SPDX-License-Identifier: Apache-2.0
6 *
7 *    Licensed under the Apache License, Version 2.0.
8 *
9 ******************************************************************************/
10use std::{
11    sync::{
12        Arc,
13        Mutex,
14    },
15    thread,
16    time::Duration,
17};
18
19use qubit_function::{
20    Callable,
21    Runnable,
22};
23use rayon::ThreadPool as RayonThreadPool;
24
25use qubit_executor::{
26    TaskHandle,
27    task::spi::{
28        TaskEndpointPair,
29        TaskRunner,
30    },
31};
32
33use qubit_executor::service::{
34    ExecutorService,
35    ExecutorServiceLifecycle,
36    StopReport,
37    SubmissionError,
38};
39
40use crate::{
41    pending_cancel::PendingCancel,
42    rayon_executor_service_build_error::RayonExecutorServiceBuildError,
43    rayon_executor_service_builder::RayonExecutorServiceBuilder,
44    rayon_executor_service_state::RayonExecutorServiceState,
45    rayon_task_handle::RayonTaskHandle,
46};
47
48/// Rayon-backed executor service for CPU-bound synchronous tasks.
49///
50/// Accepted tasks are executed on a dedicated Rayon thread pool. The service
51/// preserves the crate's `ExecutorService` lifecycle semantics and task-handle
52/// APIs while delegating scheduling to Rayon.
53#[derive(Clone)]
54pub struct RayonExecutorService {
55    /// Rayon thread pool used to execute accepted tasks.
56    pub(crate) pool: Arc<RayonThreadPool>,
57    /// Shared lifecycle and cancellation state.
58    pub(crate) state: Arc<RayonExecutorServiceState>,
59}
60
61impl RayonExecutorService {
62    /// Creates a Rayon executor service with default builder settings.
63    ///
64    /// # Returns
65    ///
66    /// `Ok(RayonExecutorService)` if the default Rayon thread pool can be
67    /// built.
68    ///
69    /// # Errors
70    ///
71    /// Returns [`RayonExecutorServiceBuildError`] if the default builder
72    /// configuration is rejected.
73    #[inline]
74    pub fn new() -> Result<Self, RayonExecutorServiceBuildError> {
75        Self::builder().build()
76    }
77
78    /// Creates a builder for configuring a Rayon executor service.
79    ///
80    /// # Returns
81    ///
82    /// A builder configured with CPU-parallelism defaults.
83    #[inline]
84    pub fn builder() -> RayonExecutorServiceBuilder {
85        RayonExecutorServiceBuilder::default()
86    }
87}
88
89impl ExecutorService for RayonExecutorService {
90    type ResultHandle<R, E>
91        = TaskHandle<R, E>
92    where
93        R: Send + 'static,
94        E: Send + 'static;
95
96    type TrackedHandle<R, E>
97        = RayonTaskHandle<R, E>
98    where
99        R: Send + 'static,
100        E: Send + 'static;
101
102    /// Accepts a runnable and schedules it on the Rayon thread pool.
103    fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
104    where
105        T: Runnable<E> + Send + 'static,
106        E: Send + 'static,
107    {
108        let submission_guard = self.state.lock_submission();
109        if self.state.is_not_running() {
110            return Err(SubmissionError::Shutdown);
111        }
112        let task_id = self.state.next_task_id();
113        self.state.on_task_accepted();
114        let cancel: PendingCancel = Arc::new(|| true);
115        self.state
116            .register_pending_task(task_id, Arc::clone(&cancel));
117        drop(submission_guard);
118
119        let state_for_run = Arc::clone(&self.state);
120        self.pool.spawn_fifo(move || {
121            if !state_for_run.start_pending_task(task_id, || true) {
122                return;
123            }
124            let mut task = task;
125            let _ignored = TaskRunner::new(move || task.run()).call::<(), E>();
126            state_for_run.on_task_completed();
127        });
128        Ok(())
129    }
130
131    /// Accepts a callable and schedules it on the Rayon thread pool.
132    ///
133    /// # Parameters
134    ///
135    /// * `task` - Callable to execute on a Rayon worker.
136    ///
137    /// # Returns
138    ///
139    /// A [`RayonTaskHandle`] for the accepted task.
140    ///
141    /// # Errors
142    ///
143    /// Returns [`SubmissionError::Shutdown`] if shutdown has already been
144    /// requested before the task is accepted.
145    fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
146    where
147        C: Callable<R, E> + Send + 'static,
148        R: Send + 'static,
149        E: Send + 'static,
150    {
151        let submission_guard = self.state.lock_submission();
152        if self.state.is_not_running() {
153            return Err(SubmissionError::Shutdown);
154        }
155        let task_id = self.state.next_task_id();
156        self.state.on_task_accepted();
157        let (handle, completion) = TaskEndpointPair::new().into_parts();
158        completion.accept();
159        let completion = Arc::new(Mutex::new(Some(completion)));
160        let completion_for_cancel = Arc::clone(&completion);
161        let cancel: PendingCancel = Arc::new(move || {
162            let completion = completion_for_cancel
163                .lock()
164                .unwrap_or_else(std::sync::PoisonError::into_inner)
165                .take();
166            completion.is_some_and(|completion| completion.cancel_unstarted())
167        });
168        self.state
169            .register_pending_task(task_id, Arc::clone(&cancel));
170        drop(submission_guard);
171
172        let completion_for_run = completion;
173        let state_for_run = Arc::clone(&self.state);
174        self.pool.spawn_fifo(move || {
175            if !state_for_run.start_pending_task(task_id, || true) {
176                return;
177            }
178            let completion = completion_for_run
179                .lock()
180                .unwrap_or_else(std::sync::PoisonError::into_inner)
181                .take();
182            if let Some(completion) = completion {
183                TaskRunner::new(task).run(completion);
184            }
185            state_for_run.on_task_completed();
186        });
187        Ok(handle)
188    }
189
190    /// Accepts a callable and schedules it with a tracked handle.
191    fn submit_tracked_callable<C, R, E>(
192        &self,
193        task: C,
194    ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
195    where
196        C: Callable<R, E> + Send + 'static,
197        R: Send + 'static,
198        E: Send + 'static,
199    {
200        let submission_guard = self.state.lock_submission();
201        if self.state.is_not_running() {
202            return Err(SubmissionError::Shutdown);
203        }
204        let task_id = self.state.next_task_id();
205        self.state.on_task_accepted();
206        let (handle, completion) = TaskEndpointPair::new().into_tracked_parts();
207        completion.accept();
208        let completion = Arc::new(Mutex::new(Some(completion)));
209        let completion_for_cancel = Arc::clone(&completion);
210        let cancel: PendingCancel = Arc::new(move || {
211            let completion = completion_for_cancel
212                .lock()
213                .unwrap_or_else(std::sync::PoisonError::into_inner)
214                .take();
215            completion.is_some_and(|completion| completion.cancel_unstarted())
216        });
217        self.state
218            .register_pending_task(task_id, Arc::clone(&cancel));
219        drop(submission_guard);
220
221        let completion_for_run = completion;
222        let state_for_run = Arc::clone(&self.state);
223        self.pool.spawn_fifo(move || {
224            if !state_for_run.start_pending_task(task_id, || true) {
225                return;
226            }
227            let completion = completion_for_run
228                .lock()
229                .unwrap_or_else(std::sync::PoisonError::into_inner)
230                .take();
231            if let Some(completion) = completion {
232                TaskRunner::new(task).run(completion);
233            }
234            state_for_run.on_task_completed();
235        });
236        Ok(RayonTaskHandle::new(
237            handle,
238            task_id,
239            Arc::clone(&self.state),
240            cancel,
241        ))
242    }
243
244    /// Stops accepting new tasks.
245    ///
246    /// Already accepted Rayon tasks are allowed to finish normally.
247    fn shutdown(&self) {
248        let _guard = self.state.lock_submission();
249        self.state.shutdown();
250        self.state.notify_if_terminated();
251    }
252
253    /// Stops accepting new tasks and cancels tasks that have not started yet.
254    ///
255    /// Running Rayon tasks cannot be preempted. Cancellation therefore applies
256    /// only to tasks that are still pending when the cancellation hook wins the
257    /// race against task start.
258    ///
259    /// # Returns
260    ///
261    /// A count-based report describing the pending and running work observed at
262    /// the time of the shutdown request, plus the number of pending tasks for
263    /// which cancellation succeeded.
264    fn stop(&self) -> StopReport {
265        let _guard = self.state.lock_submission();
266        self.state.stop();
267        let (queued, running, pending) = self.state.drain_pending_tasks_for_shutdown();
268        drop(_guard);
269
270        let cancelled = self.state.cancel_drained_pending_tasks(pending);
271        StopReport::new(queued, running, cancelled)
272    }
273
274    /// Returns the current lifecycle state.
275    fn lifecycle(&self) -> ExecutorServiceLifecycle {
276        self.state.lifecycle()
277    }
278
279    /// Returns whether shutdown has been requested.
280    fn is_not_running(&self) -> bool {
281        self.state.is_not_running()
282    }
283
284    /// Returns whether shutdown was requested and no accepted tasks remain.
285    fn is_terminated(&self) -> bool {
286        self.lifecycle() == ExecutorServiceLifecycle::Terminated
287    }
288
289    /// Blocks until the service has terminated.
290    fn wait_termination(&self) {
291        while !self.is_terminated() {
292            thread::sleep(Duration::from_millis(1));
293        }
294    }
295}