qubit_tokio_executor/
tokio_executor_service.rs1use std::{
11 future::Future,
12 pin::Pin,
13 sync::{
14 Arc,
15 Mutex,
16 },
17};
18
19use qubit_function::{
20 Callable,
21 Runnable,
22};
23
24use qubit_executor::TaskHandle;
25use qubit_executor::task::spi::{
26 TaskEndpointPair,
27 TaskRunner,
28};
29
30use crate::TokioBlockingTaskHandle;
31use crate::tokio_executor_service_state::TokioExecutorServiceState;
32use crate::tokio_service_task_guard::TokioServiceTaskGuard;
33use qubit_executor::service::{
34 ExecutorService,
35 ExecutorServiceLifecycle,
36 StopReport,
37 SubmissionError,
38};
39
40#[derive(Default, Clone)]
45pub struct TokioExecutorService {
46 state: Arc<TokioExecutorServiceState>,
48}
49
50pub type TokioBlockingExecutorService = TokioExecutorService;
52
53impl TokioExecutorService {
54 #[inline]
60 pub fn new() -> Self {
61 Self::default()
62 }
63}
64
65impl ExecutorService for TokioExecutorService {
66 type ResultHandle<R, E>
67 = TaskHandle<R, E>
68 where
69 R: Send + 'static,
70 E: Send + 'static;
71
72 type TrackedHandle<R, E>
73 = TokioBlockingTaskHandle<R, E>
74 where
75 R: Send + 'static,
76 E: Send + 'static;
77
78 fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
93 where
94 T: Runnable<E> + Send + 'static,
95 E: Send + 'static,
96 {
97 let submission_guard = self.state.lock_submission();
98 if self.state.is_not_running() {
99 return Err(SubmissionError::Shutdown);
100 }
101 self.state.accept_task();
102
103 let marker = Arc::new(());
104 let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
105 let abort_queued_task = guard.finish_queued_once_callback();
106 let handle = tokio::task::spawn_blocking(move || {
107 let guard = guard;
108 if !guard.mark_started() {
109 return;
110 }
111 let mut task = task;
112 let runner = TaskRunner::new(move || task.run());
113 let _ = runner.call::<(), E>();
114 });
115 self.state
116 .register_abort_handle(marker, handle.abort_handle(), abort_queued_task);
117 drop(submission_guard);
118 Ok(())
119 }
120
121 fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
136 where
137 C: Callable<R, E> + Send + 'static,
138 R: Send + 'static,
139 E: Send + 'static,
140 {
141 let submission_guard = self.state.lock_submission();
142 if self.state.is_not_running() {
143 return Err(SubmissionError::Shutdown);
144 }
145 self.state.accept_task();
146
147 let (handle, completion) = TaskEndpointPair::new().into_parts();
148 completion.accept();
149 let completion = Arc::new(Mutex::new(Some(completion)));
150 let abort_completion = Arc::clone(&completion);
151 let marker = Arc::new(());
152 let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
153 let abort_queued_task = guard.finish_queued_once_callback();
154 let join_handle = tokio::task::spawn_blocking(move || {
155 let guard = guard;
156 if !guard.mark_started() {
157 return;
158 }
159 let completion = completion
160 .lock()
161 .unwrap_or_else(std::sync::PoisonError::into_inner)
162 .take();
163 if let Some(completion) = completion {
164 TaskRunner::new(task).run(completion);
165 }
166 });
167 self.state
168 .register_abort_handle(marker, join_handle.abort_handle(), move || {
169 let completion = abort_completion
170 .lock()
171 .unwrap_or_else(std::sync::PoisonError::into_inner)
172 .take();
173 if let Some(completion) = completion {
174 let _cancelled = completion.cancel_unstarted();
175 }
176 abort_queued_task();
177 });
178 drop(submission_guard);
179 Ok(handle)
180 }
181
182 fn submit_tracked_callable<C, R, E>(
197 &self,
198 task: C,
199 ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
200 where
201 C: Callable<R, E> + Send + 'static,
202 R: Send + 'static,
203 E: Send + 'static,
204 {
205 let submission_guard = self.state.lock_submission();
206 if self.state.is_not_running() {
207 return Err(SubmissionError::Shutdown);
208 }
209 self.state.accept_task();
210
211 let (handle, completion) = TaskEndpointPair::new().into_tracked_parts();
212 completion.accept();
213 let completion = Arc::new(Mutex::new(Some(completion)));
214 let abort_completion = Arc::clone(&completion);
215 let marker = Arc::new(());
216 let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
217 let abort_queued_task = guard.finish_queued_once_callback();
218 let cancel_queued_task = guard.finish_queued_callback();
219 let join_handle = tokio::task::spawn_blocking(move || {
220 let guard = guard;
221 if !guard.mark_started() {
222 return;
223 }
224 let completion = completion
225 .lock()
226 .unwrap_or_else(std::sync::PoisonError::into_inner)
227 .take();
228 if let Some(completion) = completion {
229 TaskRunner::new(task).run(completion);
230 }
231 });
232 let abort_handle = join_handle.abort_handle();
233 self.state
234 .register_abort_handle(marker, abort_handle.clone(), move || {
235 let completion = abort_completion
236 .lock()
237 .unwrap_or_else(std::sync::PoisonError::into_inner)
238 .take();
239 if let Some(completion) = completion {
240 let _cancelled = completion.cancel_unstarted();
241 }
242 abort_queued_task();
243 });
244 drop(submission_guard);
245 Ok(TokioBlockingTaskHandle::new(
246 handle,
247 abort_handle,
248 cancel_queued_task,
249 ))
250 }
251
252 fn shutdown(&self) {
257 let _guard = self.state.lock_submission();
258 self.state.shutdown();
259 self.state.notify_if_terminated();
260 }
261
262 fn stop(&self) -> StopReport {
272 let _guard = self.state.lock_submission();
273 self.state.stop();
274 let (queued_count, running_count) = self.state.task_count_snapshot();
275 let cancellation_count = self.state.abort_tracked_tasks();
276 self.state.notify_if_terminated();
277 StopReport::new(queued_count, running_count, cancellation_count)
278 }
279
280 fn lifecycle(&self) -> ExecutorServiceLifecycle {
282 self.state.lifecycle()
283 }
284
285 fn is_not_running(&self) -> bool {
287 self.state.is_not_running()
288 }
289
290 fn is_terminated(&self) -> bool {
292 self.lifecycle() == ExecutorServiceLifecycle::Terminated
293 }
294
295 fn wait_termination(&self) {
297 self.state.wait_termination();
298 }
299}
300
301impl TokioExecutorService {
302 pub fn await_termination(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
309 Box::pin(async move {
310 let notified = self.state.terminated_notify.notified();
311 tokio::pin!(notified);
312 loop {
313 notified.as_mut().enable();
314 if self.is_terminated() {
315 return;
316 }
317 notified.as_mut().await;
318 notified.set(self.state.terminated_notify.notified());
319 }
320 })
321 }
322}