Skip to main content

widgetkit_runtime/
tasks.rs

1use crate::internal::Dispatcher;
2use futures::Future;
3#[cfg(not(feature = "runtime-tokio"))]
4use futures::future::{AbortHandle, Abortable};
5use std::{collections::HashMap, pin::Pin};
6#[cfg(not(feature = "runtime-tokio"))]
7use std::thread;
8use widgetkit_core::TaskId;
9
10pub struct Tasks<'a, M> {
11    backend: &'a mut dyn TaskBackend<M>,
12}
13
14impl<'a, M> Tasks<'a, M>
15where
16    M: Send + 'static,
17{
18    pub(crate) fn new(backend: &'a mut dyn TaskBackend<M>) -> Self {
19        Self { backend }
20    }
21
22    pub fn spawn<F>(&mut self, future: F) -> TaskId
23    where
24        F: Future<Output = M> + Send + 'static,
25    {
26        self.backend.spawn_boxed(None, Box::pin(future))
27    }
28
29    pub fn spawn_named<F>(&mut self, name: impl Into<String>, future: F) -> TaskId
30    where
31        F: Future<Output = M> + Send + 'static,
32    {
33        self.backend.spawn_boxed(Some(name.into()), Box::pin(future))
34    }
35
36    pub fn cancel(&mut self, task_id: TaskId) -> bool {
37        self.backend.cancel(task_id)
38    }
39
40    pub fn cancel_all(&mut self) {
41        self.backend.cancel_all();
42    }
43}
44
45pub(crate) type BoxedFuture<M> = Pin<Box<dyn Future<Output = M> + Send + 'static>>;
46
47pub(crate) trait TaskBackend<M>: Send {
48    fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId;
49    fn cancel(&mut self, task_id: TaskId) -> bool;
50    fn cancel_all(&mut self);
51    fn reap(&mut self, task_id: TaskId);
52    fn shutdown(&mut self);
53    #[cfg(test)]
54    fn active_count(&self) -> usize;
55}
56
57pub(crate) fn task_backend<M>(dispatcher: Dispatcher<M>) -> Box<dyn TaskBackend<M>>
58where
59    M: Send + 'static,
60{
61    #[cfg(feature = "runtime-tokio")]
62    {
63        return Box::new(TokioTaskBackend::new(dispatcher));
64    }
65
66    #[cfg(not(feature = "runtime-tokio"))]
67    {
68        Box::new(DefaultTaskBackend::new(dispatcher))
69    }
70}
71
72#[cfg(not(feature = "runtime-tokio"))]
73struct DefaultTaskBackend<M> {
74    dispatcher: Dispatcher<M>,
75    tasks: HashMap<TaskId, DefaultTaskControl>,
76    shutting_down: bool,
77}
78
79#[cfg(not(feature = "runtime-tokio"))]
80struct DefaultTaskControl {
81    #[allow(dead_code)]
82    name: Option<String>,
83    abort_handle: AbortHandle,
84}
85
86#[cfg(not(feature = "runtime-tokio"))]
87impl<M> DefaultTaskBackend<M>
88where
89    M: Send + 'static,
90{
91    fn new(dispatcher: Dispatcher<M>) -> Self {
92        Self {
93            dispatcher,
94            tasks: HashMap::new(),
95            shutting_down: false,
96        }
97    }
98
99    fn close(&mut self) {
100        if self.shutting_down {
101            return;
102        }
103        self.shutting_down = true;
104        self.cancel_all();
105    }
106}
107
108#[cfg(not(feature = "runtime-tokio"))]
109impl<M> TaskBackend<M> for DefaultTaskBackend<M>
110where
111    M: Send + 'static,
112{
113    fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId {
114        let task_id = TaskId::new();
115        if self.shutting_down {
116            drop(future);
117            self.dispatcher.finish_task(task_id);
118            return task_id;
119        }
120
121        let (abort_handle, abort_registration) = AbortHandle::new_pair();
122        let dispatcher = self.dispatcher.clone();
123        thread::spawn(move || {
124            let future = Abortable::new(future, abort_registration);
125            if let Ok(message) = futures::executor::block_on(future) {
126                let _ = dispatcher.post_message(message);
127            }
128            dispatcher.finish_task(task_id);
129        });
130        self.tasks.insert(task_id, DefaultTaskControl { name, abort_handle });
131        task_id
132    }
133
134    fn cancel(&mut self, task_id: TaskId) -> bool {
135        if let Some(control) = self.tasks.remove(&task_id) {
136            control.abort_handle.abort();
137            return true;
138        }
139        false
140    }
141
142    fn cancel_all(&mut self) {
143        for (_, control) in self.tasks.drain() {
144            control.abort_handle.abort();
145        }
146    }
147
148    fn reap(&mut self, task_id: TaskId) {
149        let _ = self.tasks.remove(&task_id);
150    }
151
152    fn shutdown(&mut self) {
153        self.close();
154    }
155
156    #[cfg(test)]
157    fn active_count(&self) -> usize {
158        self.tasks.len()
159    }
160}
161
162#[cfg(not(feature = "runtime-tokio"))]
163impl<M> Drop for DefaultTaskBackend<M> {
164    fn drop(&mut self) {
165        self.shutting_down = true;
166        for (_, control) in self.tasks.drain() {
167            control.abort_handle.abort();
168        }
169    }
170}
171
/// Backend that runs every task on a Tokio multi-thread runtime it owns.
///
/// Compiled only with the `runtime-tokio` feature.
#[cfg(feature = "runtime-tokio")]
struct TokioTaskBackend<M> {
    /// Receives task output messages and completion notices.
    dispatcher: Dispatcher<M>,
    /// Runtime owned by this backend and dropped together with it.
    runtime: tokio::runtime::Runtime,
    /// Tasks spawned and not yet cancelled or reaped.
    tasks: HashMap<TaskId, TokioTaskControl>,
    /// Set once shutdown begins; later spawns are discarded immediately.
    shutting_down: bool,
}

/// Per-task bookkeeping held by `TokioTaskBackend`.
#[cfg(feature = "runtime-tokio")]
struct TokioTaskControl {
    // Stored but never read in this module, hence the lint allowance;
    // presumably kept for future diagnostics.
    #[allow(dead_code)]
    name: Option<String>,
    /// Handle used to abort the spawned Tokio task.
    join_handle: tokio::task::JoinHandle<()>,
}
186
#[cfg(feature = "runtime-tokio")]
impl<M> TokioTaskBackend<M>
where
    M: Send + 'static,
{
    /// Builds a multi-thread Tokio runtime with all drivers enabled and an
    /// empty task table.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime cannot be built.
    fn new(dispatcher: Dispatcher<M>) -> Self {
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        let runtime = builder
            .enable_all()
            .build()
            .expect("tokio runtime backend must initialize");
        TokioTaskBackend {
            dispatcher,
            runtime,
            tasks: HashMap::new(),
            shutting_down: false,
        }
    }

    /// Enters shutdown and aborts every tracked task.
    ///
    /// Only the first call has an effect; later calls are no-ops.
    fn close(&mut self) {
        if !self.shutting_down {
            self.shutting_down = true;
            self.cancel_all();
        }
    }
}
213
#[cfg(feature = "runtime-tokio")]
impl<M> TaskBackend<M> for TokioTaskBackend<M>
where
    M: Send + 'static,
{
    /// Spawns `future` onto the backend's Tokio runtime.
    ///
    /// The output is posted to the dispatcher when the future completes.
    /// `finish_task` is reported exactly once per id: immediately when the
    /// backend is already shutting down, otherwise when the Tokio task ends.
    /// That includes normal completion, an abort through `cancel`/`cancel_all`,
    /// and a panic. This matches the default backend, which also reports
    /// aborted tasks.
    fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId {
        // Reports `finish_task` when dropped. Tokio drops a task's future when
        // it completes, when `JoinHandle::abort` cancels it, and after it
        // panics, so a guard owned by the task covers every exit path.
        struct FinishOnDrop<Msg: Send + 'static> {
            dispatcher: Dispatcher<Msg>,
            task_id: TaskId,
        }

        impl<Msg: Send + 'static> Drop for FinishOnDrop<Msg> {
            fn drop(&mut self) {
                self.dispatcher.finish_task(self.task_id);
            }
        }

        let task_id = TaskId::new();
        if self.shutting_down {
            // Never started: discard the future and report completion now.
            drop(future);
            self.dispatcher.finish_task(task_id);
            return task_id;
        }

        // Built before spawning and moved into the task, so an abort that
        // lands before the first poll still drops it and reports completion.
        let completion = FinishOnDrop {
            dispatcher: self.dispatcher.clone(),
            task_id,
        };
        let join_handle = self.runtime.spawn(async move {
            let message = future.await;
            let _ = completion.dispatcher.post_message(message);
            // Report completion only after the message has been posted.
            drop(completion);
        });
        self.tasks.insert(task_id, TokioTaskControl { name, join_handle });
        task_id
    }

    /// Aborts `task_id` and stops tracking it.
    ///
    /// Returns `false` if the task was not tracked (unknown, reaped or
    /// already cancelled).
    fn cancel(&mut self, task_id: TaskId) -> bool {
        match self.tasks.remove(&task_id) {
            Some(control) => {
                control.join_handle.abort();
                true
            }
            None => false,
        }
    }

    /// Aborts every tracked task and clears the table.
    fn cancel_all(&mut self) {
        for (_, control) in self.tasks.drain() {
            control.join_handle.abort();
        }
    }

    /// Forgets `task_id` without aborting it.
    fn reap(&mut self, task_id: TaskId) {
        let _ = self.tasks.remove(&task_id);
    }

    /// Enters shutdown (idempotent) and aborts all tracked tasks.
    fn shutdown(&mut self) {
        self.close();
    }

    #[cfg(test)]
    fn active_count(&self) -> usize {
        self.tasks.len()
    }
}
264
#[cfg(feature = "runtime-tokio")]
impl<M> Drop for TokioTaskBackend<M> {
    // Aborts every still-tracked task. The `runtime` field itself is dropped
    // afterwards, as an ordinary field of this struct.
    //
    // This cannot call `close`/`cancel_all`: those live behind
    // `M: Send + 'static`, and a `Drop` impl may not add bounds that the
    // struct definition does not declare.
    fn drop(&mut self) {
        self.shutting_down = true;
        self.tasks
            .drain()
            .for_each(|(_, control)| control.join_handle.abort());
    }
}