qubit_tokio_executor/
tokio_executor_service.rs1use std::{
11 sync::{
12 Arc,
13 Mutex,
14 },
15 thread,
16 time::Duration,
17};
18
19use qubit_function::{
20 Callable,
21 Runnable,
22};
23
24use qubit_executor::{
25 TaskHandle,
26 TrackedTask,
27 task::spi::{
28 TaskEndpointPair,
29 TaskRunner,
30 },
31};
32
33use crate::tokio_executor_service_state::TokioExecutorServiceState;
34use crate::tokio_service_task_guard::TokioServiceTaskGuard;
35use qubit_executor::service::{
36 ExecutorService,
37 ExecutorServiceLifecycle,
38 StopReport,
39 SubmissionError,
40};
41
42#[derive(Default, Clone)]
47pub struct TokioExecutorService {
48 state: Arc<TokioExecutorServiceState>,
50}
51
52pub type TokioBlockingExecutorService = TokioExecutorService;
54
55impl TokioExecutorService {
56 #[inline]
62 pub fn new() -> Self {
63 Self::default()
64 }
65}
66
67impl ExecutorService for TokioExecutorService {
68 type ResultHandle<R, E>
69 = TaskHandle<R, E>
70 where
71 R: Send + 'static,
72 E: Send + 'static;
73
74 type TrackedHandle<R, E>
75 = TrackedTask<R, E>
76 where
77 R: Send + 'static,
78 E: Send + 'static;
79
80 fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
95 where
96 T: Runnable<E> + Send + 'static,
97 E: Send + 'static,
98 {
99 let submission_guard = self.state.lock_submission();
100 if self.state.is_not_running() {
101 return Err(SubmissionError::Shutdown);
102 }
103 self.state.active_tasks.inc();
104
105 let marker = Arc::new(());
106 let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
107 let handle = tokio::task::spawn_blocking(move || {
108 let _guard = guard;
109 let mut task = task;
110 let runner = TaskRunner::new(move || task.run());
111 let _ = runner.call::<(), E>();
112 });
113 self.state
114 .register_abort_handle(marker, handle.abort_handle(), || {});
115 drop(submission_guard);
116 Ok(())
117 }
118
119 fn submit_callable<C, R, E>(&self, task: C) -> Result<Self::ResultHandle<R, E>, SubmissionError>
134 where
135 C: Callable<R, E> + Send + 'static,
136 R: Send + 'static,
137 E: Send + 'static,
138 {
139 let submission_guard = self.state.lock_submission();
140 if self.state.is_not_running() {
141 return Err(SubmissionError::Shutdown);
142 }
143 self.state.active_tasks.inc();
144
145 let (handle, completion) = TaskEndpointPair::new().into_parts();
146 completion.accept();
147 let completion = Arc::new(Mutex::new(Some(completion)));
148 let abort_completion = Arc::clone(&completion);
149 let marker = Arc::new(());
150 let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
151 let join_handle = tokio::task::spawn_blocking(move || {
152 let _guard = guard;
153 let completion = completion
154 .lock()
155 .unwrap_or_else(std::sync::PoisonError::into_inner)
156 .take();
157 if let Some(completion) = completion {
158 TaskRunner::new(task).run(completion);
159 }
160 });
161 self.state
162 .register_abort_handle(marker, join_handle.abort_handle(), move || {
163 let completion = abort_completion
164 .lock()
165 .unwrap_or_else(std::sync::PoisonError::into_inner)
166 .take();
167 if let Some(completion) = completion {
168 let _cancelled = completion.cancel_unstarted();
169 }
170 });
171 drop(submission_guard);
172 Ok(handle)
173 }
174
175 fn submit_tracked_callable<C, R, E>(
190 &self,
191 task: C,
192 ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
193 where
194 C: Callable<R, E> + Send + 'static,
195 R: Send + 'static,
196 E: Send + 'static,
197 {
198 let submission_guard = self.state.lock_submission();
199 if self.state.is_not_running() {
200 return Err(SubmissionError::Shutdown);
201 }
202 self.state.active_tasks.inc();
203
204 let (handle, completion) = TaskEndpointPair::new().into_tracked_parts();
205 completion.accept();
206 let completion = Arc::new(Mutex::new(Some(completion)));
207 let abort_completion = Arc::clone(&completion);
208 let marker = Arc::new(());
209 let guard = TokioServiceTaskGuard::new(Arc::clone(&self.state), Arc::clone(&marker));
210 let join_handle = tokio::task::spawn_blocking(move || {
211 let _guard = guard;
212 let completion = completion
213 .lock()
214 .unwrap_or_else(std::sync::PoisonError::into_inner)
215 .take();
216 if let Some(completion) = completion {
217 TaskRunner::new(task).run(completion);
218 }
219 });
220 self.state
221 .register_abort_handle(marker, join_handle.abort_handle(), move || {
222 let completion = abort_completion
223 .lock()
224 .unwrap_or_else(std::sync::PoisonError::into_inner)
225 .take();
226 if let Some(completion) = completion {
227 let _cancelled = completion.cancel_unstarted();
228 }
229 });
230 drop(submission_guard);
231 Ok(handle)
232 }
233
234 fn shutdown(&self) {
239 let _guard = self.state.lock_submission();
240 self.state.shutdown();
241 self.state.notify_if_terminated();
242 }
243
244 fn stop(&self) -> StopReport {
254 let _guard = self.state.lock_submission();
255 self.state.stop();
256 let running = self.state.active_tasks.get();
257 let cancellation_count = self.state.abort_tracked_tasks();
258 self.state.notify_if_terminated();
259 StopReport::new(0, running, cancellation_count)
260 }
261
262 fn lifecycle(&self) -> ExecutorServiceLifecycle {
264 self.state.lifecycle()
265 }
266
267 fn is_not_running(&self) -> bool {
269 self.state.is_not_running()
270 }
271
272 fn is_terminated(&self) -> bool {
274 self.lifecycle() == ExecutorServiceLifecycle::Terminated
275 }
276
277 fn wait_termination(&self) {
279 while !self.is_terminated() {
280 thread::sleep(Duration::from_millis(1));
281 }
282 }
283}