Skip to main content

qubit_tokio_executor/
tokio_executor_service.rs

1/*******************************************************************************
2 *
3 *    Copyright (c) 2025 - 2026 Haixing Hu.
4 *
5 *    SPDX-License-Identifier: Apache-2.0
6 *
7 *    Licensed under the Apache License, Version 2.0.
8 *
9 ******************************************************************************/
10use std::{
11    future::Future,
12    pin::Pin,
13    sync::{
14        Arc,
15        Mutex,
16    },
17};
18
19use qubit_function::{
20    Callable,
21    Runnable,
22};
23
24use qubit_executor::TaskHandle;
25use qubit_executor::task::spi::{
26    TaskEndpointPair,
27    TaskRunner,
28};
29
30use crate::TokioBlockingTaskHandle;
31use crate::tokio_executor_service_state::TokioExecutorServiceState;
32use crate::tokio_service_task_guard::TokioServiceTaskGuard;
33use qubit_executor::service::{
34    ExecutorService,
35    ExecutorServiceLifecycle,
36    StopReport,
37    SubmissionError,
38};
39
40/// Tokio-backed service for submitted blocking tasks.
41///
42/// The service accepts fallible [`Runnable`](qubit_function::Runnable) and
43/// [`Callable`] tasks and runs them through Tokio's blocking task pool.
44#[derive(Default, Clone)]
45pub struct TokioExecutorService {
46    /// Shared service state used by all clones of this service.
47    state: Arc<TokioExecutorServiceState>,
48}
49
50/// Tokio-backed blocking executor service routed through `spawn_blocking`.
51pub type TokioBlockingExecutorService = TokioExecutorService;
52
53impl TokioExecutorService {
54    /// Creates a new service instance.
55    ///
56    /// # Returns
57    ///
58    /// A Tokio-backed executor service.
59    #[inline]
60    pub fn new() -> Self {
61        Self::default()
62    }
63}
64
65impl ExecutorService for TokioExecutorService {
66    type ResultHandle<R, E>
67        = TaskHandle<R, E>
68    where
69        R: Send + 'static,
70        E: Send + 'static;
71
72    type TrackedHandle<R, E>
73        = TokioBlockingTaskHandle<R, E>
74    where
75        R: Send + 'static,
76        E: Send + 'static;
77
78    /// Accepts a runnable and runs it through Tokio.
79    ///
80    /// # Parameters
81    ///
82    /// * `task` - Runnable to execute on Tokio's blocking task pool.
83    ///
84    /// # Returns
85    ///
86    /// `Ok(())` if the task was accepted.
87    ///
88    /// # Errors
89    ///
90    /// Returns [`SubmissionError::Shutdown`] if shutdown has already been
91    /// requested before the task is accepted.
92    fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
93    where
94        T: Runnable<E> + Send + 'static,
95        E: Send + 'static,
96    {
97        let submission_guard = self.state.lock_submission();
98        if self.state.is_not_running() {
99            return Err(SubmissionError::Shutdown);
100        }
101        self.state.accept_task();
102
103        let marker = Arc::new(());
104        let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
105        let abort_queued_task = guard.finish_queued_once_callback();
106        let handle = tokio::task::spawn_blocking(move || {
107            let guard = guard;
108            if !guard.mark_started() {
109                return;
110            }
111            let mut task = task;
112            let runner = TaskRunner::new(move || task.run());
113            let _ = runner.call::<(), E>();
114        });
115        self.state
116            .register_abort_handle(marker, handle.abort_handle(), abort_queued_task);
117        drop(submission_guard);
118        Ok(())
119    }
120
121    /// Accepts a callable and runs it through Tokio.
122    ///
123    /// # Parameters
124    ///
125    /// * `task` - Callable to execute on Tokio's blocking task pool.
126    ///
127    /// # Returns
128    ///
129    /// A [`TaskHandle`] for the accepted task.
130    ///
131    /// # Errors
132    ///
133    /// Returns [`SubmissionError::Shutdown`] if shutdown has already been
134    /// requested before the task is accepted.
135    fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
136    where
137        C: Callable<R, E> + Send + 'static,
138        R: Send + 'static,
139        E: Send + 'static,
140    {
141        let submission_guard = self.state.lock_submission();
142        if self.state.is_not_running() {
143            return Err(SubmissionError::Shutdown);
144        }
145        self.state.accept_task();
146
147        let (handle, completion) = TaskEndpointPair::new().into_parts();
148        completion.accept();
149        let completion = Arc::new(Mutex::new(Some(completion)));
150        let abort_completion = Arc::clone(&completion);
151        let marker = Arc::new(());
152        let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
153        let abort_queued_task = guard.finish_queued_once_callback();
154        let join_handle = tokio::task::spawn_blocking(move || {
155            let guard = guard;
156            if !guard.mark_started() {
157                return;
158            }
159            let completion = completion
160                .lock()
161                .unwrap_or_else(std::sync::PoisonError::into_inner)
162                .take();
163            if let Some(completion) = completion {
164                TaskRunner::new(task).run(completion);
165            }
166        });
167        self.state
168            .register_abort_handle(marker, join_handle.abort_handle(), move || {
169                let completion = abort_completion
170                    .lock()
171                    .unwrap_or_else(std::sync::PoisonError::into_inner)
172                    .take();
173                if let Some(completion) = completion {
174                    let _cancelled = completion.cancel_unstarted();
175                }
176                abort_queued_task();
177            });
178        drop(submission_guard);
179        Ok(handle)
180    }
181
182    /// Accepts a callable and returns an actively tracked handle.
183    ///
184    /// # Parameters
185    ///
186    /// * `task` - Callable to execute on Tokio's blocking task pool.
187    ///
188    /// # Returns
189    ///
190    /// A [`TokioBlockingTaskHandle`] for the accepted task.
191    ///
192    /// # Errors
193    ///
194    /// Returns [`SubmissionError::Shutdown`] if shutdown has already been
195    /// requested before the task is accepted.
196    fn submit_tracked_callable<C, R, E>(
197        &self,
198        task: C,
199    ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
200    where
201        C: Callable<R, E> + Send + 'static,
202        R: Send + 'static,
203        E: Send + 'static,
204    {
205        let submission_guard = self.state.lock_submission();
206        if self.state.is_not_running() {
207            return Err(SubmissionError::Shutdown);
208        }
209        self.state.accept_task();
210
211        let (handle, completion) = TaskEndpointPair::new().into_tracked_parts();
212        completion.accept();
213        let completion = Arc::new(Mutex::new(Some(completion)));
214        let abort_completion = Arc::clone(&completion);
215        let marker = Arc::new(());
216        let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
217        let abort_queued_task = guard.finish_queued_once_callback();
218        let cancel_queued_task = guard.finish_queued_callback();
219        let join_handle = tokio::task::spawn_blocking(move || {
220            let guard = guard;
221            if !guard.mark_started() {
222                return;
223            }
224            let completion = completion
225                .lock()
226                .unwrap_or_else(std::sync::PoisonError::into_inner)
227                .take();
228            if let Some(completion) = completion {
229                TaskRunner::new(task).run(completion);
230            }
231        });
232        let abort_handle = join_handle.abort_handle();
233        self.state
234            .register_abort_handle(marker, abort_handle.clone(), move || {
235                let completion = abort_completion
236                    .lock()
237                    .unwrap_or_else(std::sync::PoisonError::into_inner)
238                    .take();
239                if let Some(completion) = completion {
240                    let _cancelled = completion.cancel_unstarted();
241                }
242                abort_queued_task();
243            });
244        drop(submission_guard);
245        Ok(TokioBlockingTaskHandle::new(
246            handle,
247            abort_handle,
248            cancel_queued_task,
249        ))
250    }
251
252    /// Stops accepting new tasks.
253    ///
254    /// Already accepted tasks are allowed to finish unless they are cancelled
255    /// before their blocking closure starts.
256    fn shutdown(&self) {
257        let _guard = self.state.lock_submission();
258        self.state.shutdown();
259        self.state.notify_if_terminated();
260    }
261
262    /// Stops accepting new tasks and requests abort for tracked Tokio tasks.
263    ///
264    /// Tokio cannot abort blocking tasks that have already started. Such tasks
265    /// continue running and keep the service active until their closure returns.
266    ///
267    /// # Returns
268    ///
269    /// A report with queued and running blocking task counts observed when
270    /// stop was requested, plus the number of Tokio abort handles signalled.
271    fn stop(&self) -> StopReport {
272        let _guard = self.state.lock_submission();
273        self.state.stop();
274        let (queued_count, running_count) = self.state.task_count_snapshot();
275        let cancellation_count = self.state.abort_tracked_tasks();
276        self.state.notify_if_terminated();
277        StopReport::new(queued_count, running_count, cancellation_count)
278    }
279
280    /// Returns the current lifecycle state.
281    fn lifecycle(&self) -> ExecutorServiceLifecycle {
282        self.state.lifecycle()
283    }
284
285    /// Returns whether shutdown has been requested.
286    fn is_not_running(&self) -> bool {
287        self.state.is_not_running()
288    }
289
290    /// Returns whether shutdown was requested and all tasks are finished.
291    fn is_terminated(&self) -> bool {
292        self.lifecycle() == ExecutorServiceLifecycle::Terminated
293    }
294
295    /// Blocks until the service has terminated.
296    fn wait_termination(&self) {
297        self.state.wait_termination();
298    }
299}
300
301impl TokioExecutorService {
302    /// Waits asynchronously until the service has terminated.
303    ///
304    /// # Returns
305    ///
306    /// A future that resolves after shutdown or stop has been requested and all
307    /// accepted blocking tasks have finished or been aborted before start.
308    pub fn await_termination(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
309        Box::pin(async move {
310            let notified = self.state.terminated_notify.notified();
311            tokio::pin!(notified);
312            loop {
313                notified.as_mut().enable();
314                if self.is_terminated() {
315                    return;
316                }
317                notified.as_mut().await;
318                notified.set(self.state.terminated_notify.notified());
319            }
320        })
321    }
322}