js_utils/
spawn.rs

1//! Background task spawning.
2
3use futures::channel::oneshot::{self, Sender};
4use futures::{select, Future, FutureExt};
5use std::rc::Rc;
6use std::sync::Mutex;
7use std::task::{Poll, Waker};
8
9/// Spawns a new asynchronous task, returning a [`JoinHandle`] for it.
10pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
11where
12    F: Future + 'static,
13    F::Output: 'static,
14{
15    let (cancel_sender, mut cancel_receiver) = oneshot::channel();
16    let join_handle = JoinHandle::new(cancel_sender);
17    let join_handle_clone = join_handle.clone();
18    wasm_bindgen_futures::spawn_local(async move {
19        select! {
20            result = future.fuse() => join_handle_clone.set_result(result),
21            _ = cancel_receiver => ()
22        }
23    });
24    join_handle
25}
26
/// Task failed to execute to completion.
///
/// Currently can only be caused by cancellation.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// propagated with `?` into boxed or dynamic error types.
#[derive(Debug)]
#[non_exhaustive]
pub struct JoinError {}

impl JoinError {
    /// Returns true if the error was caused by the task being cancelled.
    ///
    /// Cancellation is the only failure this type models at present, so this
    /// always returns `true`.
    pub fn is_cancelled(&self) -> bool {
        true
    }
}

impl std::fmt::Display for JoinError {
    /// Writes a short human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task was cancelled")
    }
}

impl std::error::Error for JoinError {}
40
/// An owned permission to join on a task (await its termination).
///
/// This can be thought of as the equivalent of [`std::thread::JoinHandle`] for
/// a task rather than a thread.
///
/// A `JoinHandle` *detaches* the associated task when it is dropped, which
/// means that there is no longer any handle to the task, and no way to `join`
/// on it.
///
/// Awaiting the handle resolves to `Result<T, JoinError>`: the task's output,
/// or a [`JoinError`] recorded by `abort`.
///
/// This `struct` is created by the [`spawn`] function.
#[derive(Debug)]
pub struct JoinHandle<T> {
    // Completion state shared (through the `Rc`) between the handle returned
    // by `spawn` and the private clone that is moved into the spawned task.
    state: Rc<Mutex<State<T>>>,
    // Sending half of the cancellation channel. `abort` takes it out and sends
    // `()` on it; the task-side clone is always constructed with `None` here.
    cancel_sender: Mutex<Option<Sender<()>>>,
}
56
57impl<T> JoinHandle<T> {
58    fn new(cancel_sender: Sender<()>) -> Self {
59        JoinHandle {
60            state: Rc::new(Mutex::new(State::new())),
61            cancel_sender: Mutex::new(Some(cancel_sender)),
62        }
63    }
64
65    /// Abort the task associated with the handle.
66    ///
67    /// Awaiting a cancelled task might complete as usual if the task was
68    /// already completed at the time it was cancelled, but most likely it
69    /// will fail with a [cancelled] [`JoinError`].
70    ///
71    /// [cancelled]: method@crate::spawn::JoinError::is_cancelled
72    pub fn abort(&self) {
73        self.state.lock().unwrap().set_result(Err(JoinError {}));
74        if let Some(sender) = self.cancel_sender.lock().unwrap().take() {
75            let _ = sender.send(());
76        }
77    }
78
79    /// Checks if the task associated with this `JoinHandle` has finished.
80    ///
81    /// Please note that this method can return `false` even if [`abort`] has been
82    /// called on the task. This is because the cancellation process may take
83    /// some time, and this method does not return `true` until it has
84    /// completed.
85    ///
86    /// [`abort`]: method@JoinHandle::abort
87    pub fn is_finished(&self) -> bool {
88        self.state.lock().unwrap().is_finished()
89    }
90
91    fn set_result(&self, value: T) {
92        self.state.lock().unwrap().set_result(Ok(value));
93    }
94
95    fn clone(&self) -> Self {
96        JoinHandle {
97            state: self.state.clone(),
98            cancel_sender: Mutex::new(None),
99        }
100    }
101}
102
103#[derive(Debug)]
104struct State<T> {
105    result: Option<Result<T, JoinError>>,
106    waker: Option<Waker>,
107}
108
109impl<T> State<T> {
110    fn new() -> Self {
111        State {
112            result: None,
113            waker: None,
114        }
115    }
116
117    fn is_finished(&self) -> bool {
118        self.result.is_some()
119    }
120
121    fn set_result(&mut self, value: Result<T, JoinError>) {
122        if self.result.is_none() {
123            self.result = Some(value);
124            self.wake();
125        }
126    }
127
128    fn wake(&mut self) {
129        if let Some(waker) = self.waker.take() {
130            waker.wake();
131        }
132    }
133
134    fn update_waker(&mut self, waker: &Waker) {
135        if let Some(current_waker) = &self.waker {
136            if !waker.will_wake(current_waker) {
137                self.waker = Some(waker.clone());
138            }
139        } else {
140            self.waker = Some(waker.clone())
141        }
142    }
143}
144
145impl<T> Future for JoinHandle<T> {
146    type Output = Result<T, JoinError>;
147
148    fn poll(
149        self: std::pin::Pin<&mut Self>,
150        cx: &mut std::task::Context<'_>,
151    ) -> std::task::Poll<Self::Output> {
152        let mut state = self.state.lock().unwrap();
153        if let Some(value) = state.result.take() {
154            Poll::Ready(value)
155        } else {
156            state.update_waker(cx.waker());
157            Poll::Pending
158        }
159    }
160}
161
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use wasm_bindgen_test::wasm_bindgen_test;

    use crate::{sleep, spawn};

    /// Two immediately-ready tasks are finished after a short wait and yield
    /// their respective values when joined.
    #[wasm_bindgen_test]
    async fn test_spawn() {
        let first = spawn(async { 1 });
        let second = spawn(async { 2 });

        // Give the executor time to poll both tasks.
        let grace_period = Duration::from_secs(1);
        sleep(grace_period).await;

        assert!(first.is_finished());
        assert!(second.is_finished());

        let first_value = first.await.unwrap();
        assert_eq!(first_value, 1);
        let second_value = second.await.unwrap();
        assert_eq!(second_value, 2);
    }

    /// Aborting a long-running task makes joining it fail with a
    /// cancellation error.
    #[wasm_bindgen_test]
    async fn test_abort() {
        let long_running = spawn(async {
            sleep(Duration::from_secs(10)).await;
            1
        });
        long_running.abort();

        let outcome = long_running.await;
        assert!(outcome.unwrap_err().is_cancelled());
    }
}