widgetkit_runtime/
tasks.rs1use crate::internal::Dispatcher;
2use futures::Future;
3#[cfg(not(feature = "runtime-tokio"))]
4use futures::future::{AbortHandle, Abortable};
5#[cfg(not(feature = "runtime-tokio"))]
6use std::thread;
7use std::{collections::HashMap, pin::Pin};
8use widgetkit_core::TaskId;
9
10pub struct Tasks<'a, M> {
11 backend: &'a mut dyn TaskBackend<M>,
12}
13
14impl<'a, M> Tasks<'a, M>
15where
16 M: Send + 'static,
17{
18 pub(crate) fn new(backend: &'a mut dyn TaskBackend<M>) -> Self {
19 Self { backend }
20 }
21
22 pub fn spawn<F>(&mut self, future: F) -> TaskId
23 where
24 F: Future<Output = M> + Send + 'static,
25 {
26 self.backend.spawn_boxed(None, Box::pin(future))
27 }
28
29 pub fn spawn_named<F>(&mut self, name: impl Into<String>, future: F) -> TaskId
30 where
31 F: Future<Output = M> + Send + 'static,
32 {
33 self.backend
34 .spawn_boxed(Some(name.into()), Box::pin(future))
35 }
36
37 pub fn cancel(&mut self, task_id: TaskId) -> bool {
38 self.backend.cancel(task_id)
39 }
40
41 pub fn cancel_all(&mut self) {
42 self.backend.cancel_all();
43 }
44}
45
46pub(crate) type BoxedFuture<M> = Pin<Box<dyn Future<Output = M> + Send + 'static>>;
47
48pub(crate) trait TaskBackend<M>: Send {
49 fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId;
50 fn cancel(&mut self, task_id: TaskId) -> bool;
51 fn cancel_all(&mut self);
52 fn reap(&mut self, task_id: TaskId);
53 fn shutdown(&mut self);
54 #[cfg(test)]
55 fn active_count(&self) -> usize;
56}
57
58pub(crate) fn task_backend<M>(dispatcher: Dispatcher<M>) -> Box<dyn TaskBackend<M>>
59where
60 M: Send + 'static,
61{
62 #[cfg(feature = "runtime-tokio")]
63 {
64 return Box::new(TokioTaskBackend::new(dispatcher));
65 }
66
67 #[cfg(not(feature = "runtime-tokio"))]
68 {
69 Box::new(DefaultTaskBackend::new(dispatcher))
70 }
71}
72
73#[cfg(not(feature = "runtime-tokio"))]
74struct DefaultTaskBackend<M> {
75 dispatcher: Dispatcher<M>,
76 tasks: HashMap<TaskId, DefaultTaskControl>,
77 shutting_down: bool,
78}
79
80#[cfg(not(feature = "runtime-tokio"))]
81struct DefaultTaskControl {
82 #[allow(dead_code)]
83 name: Option<String>,
84 abort_handle: AbortHandle,
85}
86
87#[cfg(not(feature = "runtime-tokio"))]
88impl<M> DefaultTaskBackend<M>
89where
90 M: Send + 'static,
91{
92 fn new(dispatcher: Dispatcher<M>) -> Self {
93 Self {
94 dispatcher,
95 tasks: HashMap::new(),
96 shutting_down: false,
97 }
98 }
99
100 fn close(&mut self) {
101 if self.shutting_down {
102 return;
103 }
104 self.shutting_down = true;
105 self.cancel_all();
106 }
107}
108
109#[cfg(not(feature = "runtime-tokio"))]
110impl<M> TaskBackend<M> for DefaultTaskBackend<M>
111where
112 M: Send + 'static,
113{
114 fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId {
115 let task_id = TaskId::new();
116 if self.shutting_down {
117 drop(future);
118 self.dispatcher.finish_task(task_id);
119 return task_id;
120 }
121
122 let (abort_handle, abort_registration) = AbortHandle::new_pair();
123 let dispatcher = self.dispatcher.clone();
124 thread::spawn(move || {
125 let future = Abortable::new(future, abort_registration);
126 if let Ok(message) = futures::executor::block_on(future) {
127 let _ = dispatcher.post_message(message);
128 }
129 dispatcher.finish_task(task_id);
130 });
131 self.tasks
132 .insert(task_id, DefaultTaskControl { name, abort_handle });
133 task_id
134 }
135
136 fn cancel(&mut self, task_id: TaskId) -> bool {
137 if let Some(control) = self.tasks.remove(&task_id) {
138 control.abort_handle.abort();
139 return true;
140 }
141 false
142 }
143
144 fn cancel_all(&mut self) {
145 for (_, control) in self.tasks.drain() {
146 control.abort_handle.abort();
147 }
148 }
149
150 fn reap(&mut self, task_id: TaskId) {
151 let _ = self.tasks.remove(&task_id);
152 }
153
154 fn shutdown(&mut self) {
155 self.close();
156 }
157
158 #[cfg(test)]
159 fn active_count(&self) -> usize {
160 self.tasks.len()
161 }
162}
163
164#[cfg(not(feature = "runtime-tokio"))]
165impl<M> Drop for DefaultTaskBackend<M> {
166 fn drop(&mut self) {
167 self.shutting_down = true;
168 for (_, control) in self.tasks.drain() {
169 control.abort_handle.abort();
170 }
171 }
172}
173
/// Backend used with the `runtime-tokio` feature: tasks run on an owned
/// multi-threaded Tokio runtime.
#[cfg(feature = "runtime-tokio")]
struct TokioTaskBackend<M> {
    /// Used to post task results and to report finished task ids.
    dispatcher: Dispatcher<M>,
    /// Runtime that executes spawned tasks.
    // Fields drop in declaration order, so the runtime is dropped after `dispatcher`
    // and before `tasks`; keep that in mind before reordering.
    runtime: tokio::runtime::Runtime,
    /// Join handles of tasks that were spawned and not yet cancelled or reaped.
    tasks: HashMap<TaskId, TokioTaskControl>,
    /// Set once `shutdown` runs or the backend is dropped; later spawns never execute.
    shutting_down: bool,
}
181
/// Per-task bookkeeping kept by `TokioTaskBackend`.
#[cfg(feature = "runtime-tokio")]
struct TokioTaskControl {
    /// Optional label supplied through `Tasks::spawn_named`.
    // Stored but not read anywhere in this file, hence the dead-code allowance.
    #[allow(dead_code)]
    name: Option<String>,
    /// Handle used to abort the spawned Tokio task.
    join_handle: tokio::task::JoinHandle<()>,
}
188
#[cfg(feature = "runtime-tokio")]
impl<M> TokioTaskBackend<M>
where
    M: Send + 'static,
{
    /// Creates a backend owning a fresh multi-threaded Tokio runtime with all drivers
    /// enabled.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime cannot be built.
    fn new(dispatcher: Dispatcher<M>) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("tokio runtime backend must initialize");
        Self {
            runtime,
            dispatcher,
            tasks: HashMap::default(),
            shutting_down: false,
        }
    }

    /// Enters shutdown mode and aborts every tracked task; repeat calls do nothing.
    fn close(&mut self) {
        if !self.shutting_down {
            self.shutting_down = true;
            self.cancel_all();
        }
    }
}
215
#[cfg(feature = "runtime-tokio")]
impl<M> TaskBackend<M> for TokioTaskBackend<M>
where
    M: Send + 'static,
{
    /// Spawns `future` onto the owned Tokio runtime and returns its new id.
    ///
    /// If the future completes, its output is posted through the dispatcher. In every
    /// case, `finish_task` is reported exactly once for the id: after the message on
    /// completion, or when Tokio drops the future (abort, panic, runtime shutdown). This
    /// matches the default backend, which also calls `finish_task` after an aborted
    /// task. While shutting down, the future is dropped without running and the id is
    /// reported finished immediately.
    fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId {
        /// Runs the wrapped callback once, when dropped.
        struct OnDrop<F: FnOnce()>(Option<F>);

        impl<F: FnOnce()> Drop for OnDrop<F> {
            fn drop(&mut self) {
                if let Some(callback) = self.0.take() {
                    callback();
                }
            }
        }

        let task_id = TaskId::new();
        if self.shutting_down {
            drop(future);
            self.dispatcher.finish_task(task_id);
            return task_id;
        }

        let dispatcher = self.dispatcher.clone();
        let finisher = self.dispatcher.clone();
        // Built outside the async block so it is captured by value: an aborted task is
        // dropped without running the code after `.await`, possibly before its first
        // poll, and dropping the captured guard is what still reports completion.
        let finish_guard = OnDrop(Some(move || {
            finisher.finish_task(task_id);
        }));
        let join_handle = self.runtime.spawn(async move {
            // Bound to a named local so it drops at the end of the body, after the
            // message has been posted.
            let _finish_guard = finish_guard;
            let message = future.await;
            let _ = dispatcher.post_message(message);
        });
        self.tasks
            .insert(task_id, TokioTaskControl { name, join_handle });
        task_id
    }

    /// Aborts the task if it is tracked; returns whether it was.
    ///
    /// Tokio drops the aborted future, which fires the guard installed by
    /// `spawn_boxed`, so `finish_task` is still reported for the id.
    fn cancel(&mut self, task_id: TaskId) -> bool {
        match self.tasks.remove(&task_id) {
            Some(control) => {
                control.join_handle.abort();
                true
            }
            None => false,
        }
    }

    /// Aborts and forgets every tracked task.
    fn cancel_all(&mut self) {
        for (_, control) in self.tasks.drain() {
            control.join_handle.abort();
        }
    }

    /// Drops the bookkeeping for `task_id`, if any.
    fn reap(&mut self, task_id: TaskId) {
        let _ = self.tasks.remove(&task_id);
    }

    /// Permanently stops the backend; see `TokioTaskBackend::close`.
    fn shutdown(&mut self) {
        self.close();
    }

    /// Number of tasks still tracked.
    #[cfg(test)]
    fn active_count(&self) -> usize {
        self.tasks.len()
    }
}
267
#[cfg(feature = "runtime-tokio")]
impl<M> Drop for TokioTaskBackend<M> {
    /// Marks the backend as shut down and aborts every tracked task.
    ///
    /// The abort logic is repeated inline rather than calling `close`. `close` lives in
    /// an impl bounded on `M: Send + 'static`, and a `Drop` impl cannot add bounds the
    /// struct itself does not declare. The runtime is dropped afterwards along with the
    /// other fields.
    fn drop(&mut self) {
        self.shutting_down = true;
        self.tasks
            .drain()
            .map(|(_, control)| control.join_handle)
            .for_each(|handle| handle.abort());
    }
}