// local_spawn_pool/task.rs

1use futures::channel::oneshot;
2use futures::channel::oneshot::{Receiver, Sender};
3use std::boxed::Box;
4use std::cell::{Cell, RefCell};
5use std::future::Future;
6use std::ops::{Deref, DerefMut};
7use std::pin::Pin;
8use std::rc::Rc;
9use std::task::Poll;
10
11pub fn create_task<F>(future: F) -> (Task, JoinHandle<F::Output>)
12where
13    F: Future + 'static,
14{
15    let (output_tx, output_rx) = oneshot::channel::<F::Output>();
16    let abort = Rc::new(Cell::new(false));
17
18    (
19        Task::from(GenericTask {
20            future: Box::pin(future),
21            output_tx: Some(output_tx),
22            abort: Rc::clone(&abort),
23        }),
24        JoinHandle(RefCell::new(JoinHandleInner::Pending {
25            output_rx: Box::pin(output_rx),
26            abort,
27        })),
28    )
29}
30
/// A type-erased, pinned, heap-allocated future with no output, as produced by [`create_task`].
///
/// It dereferences to the inner `Pin<Box<dyn Future<Output = ()>>>`, which is what gets polled.
pub struct Task(Pin<Box<dyn Future<Output = ()>>>);
32
33impl Deref for Task {
34    type Target = Pin<Box<dyn Future<Output = ()>>>;
35
36    fn deref(&self) -> &Self::Target {
37        &self.0
38    }
39}
40
41impl DerefMut for Task {
42    fn deref_mut(&mut self) -> &mut Self::Target {
43        &mut self.0
44    }
45}
46
// Test-only: lets a `Task` be spawned directly on a tokio `LocalSet` in the tests below.
#[cfg(test)]
impl Future for Task {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        // `Task` only holds a `Pin<Box<_>>`, which is `Unpin`, so unwrapping the pin is sound.
        let task: &mut Task = self.get_mut();
        Future::poll(task.0.as_mut(), cx)
    }
}
55
56impl<F> From<GenericTask<F>> for Task
57where
58    F: Future + 'static,
59{
60    fn from(generic_task: GenericTask<F>) -> Self {
61        Self(Box::pin(generic_task))
62    }
63}
64
65struct GenericTask<F>
66where
67    F: Future + 'static,
68{
69    future: Pin<Box<F>>,
70    /// The only purpose of the `Option` is to be able to take ownership of the `Sender` in the `Future::poll` function.
71    output_tx: Option<Sender<F::Output>>,
72    abort: Rc<Cell<bool>>,
73}
74
75impl<F> Future for GenericTask<F>
76where
77    F: Future + 'static,
78{
79    type Output = ();
80
81    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
82        if self.abort.get() {
83            Poll::Ready(())
84        } else {
85            match Future::poll(self.future.as_mut(), cx) {
86                Poll::Ready(value) => {
87                    let _ = self.output_tx.take().unwrap().send(value);
88                    Poll::Ready(())
89                }
90
91                Poll::Pending => Poll::Pending,
92            }
93        }
94    }
95}
96
97/// An owned permission to join on a task (await its termination).
98///
99/// This can be thought of as the equivalent of [`std::thread::JoinHandle`] for a [`crate::LocalSpawnPool`] task rather than a thread. You
100/// do not need to `.await` the [`JoinHandle`] to make the task execute — it will start running in the background immediately. When
101/// awaiting the `JoinHandle<T>`, you will obtain an `Option<T>`, where `T` is the output of the spawned future associated with this
102/// handle: it will be `None` if the task was aborted.
103///
104/// A [`JoinHandle`] detaches the associated task when it is dropped, which means that there is no longer any handle to the task,
105/// and no way to `join` on it.
106///
107/// This `struct` is created by the [`crate::LocalSpawnPool::spawn`] and [`crate::spawn`] functions.
108///
109/// ## Cancel safety
110///
111/// The `JoinHandle<T>` type is cancel safe. If it is used as the event in a `tokio::select!` statement and some other branch
112/// completes first, then it is guaranteed that the output of the task is not lost.
113///
114/// If a [`JoinHandle`] is dropped, then the task continues running in the background and its return value is lost.
115pub struct JoinHandle<T>(RefCell<JoinHandleInner<T>>);
116
117enum JoinHandleInner<T> {
118    Pending {
119        output_rx: Pin<Box<Receiver<T>>>,
120        abort: Rc<Cell<bool>>,
121    },
122    Finished(
123        /// The only purpose of the `Option` is to be able to take ownership of `T` in the `Future::poll` function.
124        /// It should always always be `Some` before the `Future::poll` returns `Poll::Ready`.
125        Option<T>,
126    ),
127    Aborted,
128}
129
130impl<T> JoinHandle<T> {
131    fn poll(&self) {
132        let mut inner = self.0.borrow_mut();
133
134        if let JoinHandleInner::Pending {
135            output_rx,
136            abort: _,
137        } = &mut *inner
138        {
139            match output_rx.try_recv() {
140                Ok(Some(value)) => *inner = JoinHandleInner::Finished(Some(value)),
141                Ok(None) => { /* Still pending */ }
142                Err(_) => *inner = JoinHandleInner::Aborted,
143            }
144        }
145    }
146
147    /// Aborts the task associated with the handle.
148    pub fn abort(&self) {
149        let mut inner = self.0.borrow_mut();
150
151        if let JoinHandleInner::Pending {
152            output_rx: _,
153            abort,
154        } = &*inner
155        {
156            abort.set(true);
157            *inner = JoinHandleInner::Aborted;
158        }
159    }
160
161    /// Returns `true` if the task has finished executing. If the task was aborted before finishing execution, it returns `false`.
162    pub fn is_finished(&self) -> bool {
163        self.poll();
164        matches!(&*self.0.borrow(), JoinHandleInner::Finished(_))
165    }
166
167    /// Returns `true` if the task has been aborted. If [`JoinHandle::abort`] was called after the task finished executing, it still
168    /// returns `false`.
169    pub fn is_aborted(&self) -> bool {
170        self.poll();
171        matches!(&*self.0.borrow(), JoinHandleInner::Aborted)
172    }
173}
174
175impl<T> Future for JoinHandle<T> {
176    type Output = Option<T>;
177
178    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
179        let mut inner = self.0.borrow_mut();
180
181        match &mut *inner {
182            JoinHandleInner::Pending {
183                output_rx,
184                abort: _,
185            } => match Future::poll(output_rx.as_mut(), cx) {
186                Poll::Ready(Ok(value)) => {
187                    *inner = JoinHandleInner::Finished(None);
188                    Poll::Ready(Some(value))
189                }
190
191                Poll::Ready(Err(_)) => {
192                    *inner = JoinHandleInner::Aborted;
193                    Poll::Ready(None)
194                }
195
196                Poll::Pending => Poll::Pending,
197            },
198
199            JoinHandleInner::Finished(value) => Poll::Ready(value.take()),
200            JoinHandleInner::Aborted => Poll::Ready(None),
201        }
202    }
203}
204
// End-to-end check of the `create_task` / `JoinHandle` contract on a tokio `LocalSet`.
// The scenarios compare 50 ms (or 500 ms) task sleeps against 100 ms waits in the test body.
// NOTE(review): this relies on wall-clock timing and may be flaky on a heavily loaded machine.
#[cfg(test)]
#[tokio::test]
async fn test() {
    use std::time::Duration;
    use tokio::task::LocalSet;
    use tokio::time;

    let local_set = LocalSet::new();

    local_set
        .run_until(async {
            // Scenario 1: before the task completes, the handle is neither finished nor aborted;
            // awaiting it yields the task's output.
            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            assert_eq!(join_handle.await, Some("test"));

            // Scenario 2: aborting after the task finished keeps the output and does not mark
            // the handle as aborted. (`is_finished` is queried before `abort` here.)

            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            time::sleep(Duration::from_millis(100)).await;
            assert!(join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            join_handle.abort();
            assert!(join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            assert_eq!(join_handle.await, Some("test"));

            // Scenario 3: aborting through the handle before the task finished marks it as
            // aborted, and awaiting it yields `None`.

            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            join_handle.abort();
            assert!(!join_handle.is_finished());
            assert!(join_handle.is_aborted());
            assert_eq!(join_handle.await, None);

            // Scenario 4: the task is cancelled by the executor (tokio's own abort) rather than
            // through our handle; the handle observes the dropped sender and reports aborted.

            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(500)).await;
                "test"
            });
            let tokio_join_handle = tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            tokio_join_handle.abort();
            time::sleep(Duration::from_millis(100)).await;
            assert!(!join_handle.is_finished());
            assert!(join_handle.is_aborted());
            assert_eq!(join_handle.await, None);

            // Scenario 5: dropping the handle detaches the task, which still runs to completion
            // (observed through the shared `value` cell).

            let value = Rc::new(Cell::new(0i32));
            let (task, join_handle) = create_task({
                let value = Rc::clone(&value);
                async move {
                    time::sleep(Duration::from_millis(50)).await;
                    value.set(1);
                    "test"
                }
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            drop(join_handle);
            assert_eq!(value.get(), 0);
            time::sleep(Duration::from_millis(100)).await;
            assert_eq!(value.get(), 1);
        })
        .await;

    // Presumably drives any tasks still pending on the `LocalSet` (e.g. the one aborted in
    // scenario 3) until they complete — see tokio's `LocalSet` documentation.
    local_set.await;
}