js_utils/
spawn.rs

1//! Background task spawning.
2
3use futures::Future;
4use std::rc::Rc;
5use std::sync::Mutex;
6use std::task::{Poll, Waker};
7
8/// Spawns a new asynchronous task, returning a [`JoinHandle`] for it.
9pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
10where
11    F: Future + 'static,
12    F::Output: 'static,
13{
14    let join_handle = JoinHandle::new();
15    let join_handle_clone = join_handle.clone();
16    wasm_bindgen_futures::spawn_local(async move {
17        join_handle_clone.set_result(future.await);
18    });
19    join_handle
20}
21
22/// Task failed to execute to completion.
23///
24/// Currently can only be caused by cancellation.
25#[derive(Debug)]
26#[non_exhaustive]
27pub struct JoinError {}
28
29impl JoinError {
30    /// Returns true if the error was caused by the task being cancelled.
31    pub fn is_cancelled(&self) -> bool {
32        true
33    }
34}
35
36/// An owned permission to join on a task (await its termination).
37///
38/// This can be thought of as the equivalent of [`std::thread::JoinHandle`] for
39/// a task rather than a thread.
40///
41/// A `JoinHandle` *detaches* the associated task when it is dropped, which
42/// means that there is no longer any handle to the task, and no way to `join`
43/// on it.
44///
45/// This `struct` is created by the [`spawn`] function.
46#[derive(Debug)]
47pub struct JoinHandle<T> {
48    state: Rc<Mutex<State<T>>>,
49}
50
51impl<T> JoinHandle<T> {
52    fn new() -> Self {
53        JoinHandle {
54            state: Rc::new(Mutex::new(State::new())),
55        }
56    }
57
58    /// Abort the task associated with the handle.
59    ///
60    /// Awaiting a cancelled task might complete as usual if the task was
61    /// already completed at the time it was cancelled, but most likely it
62    /// will fail with a [cancelled] [`JoinError`].
63    ///
64    /// [cancelled]: method@crate::spawn::JoinError::is_cancelled
65    pub fn abort(&self) {
66        self.state.lock().unwrap().set_result(Err(JoinError {}));
67    }
68
69    /// Checks if the task associated with this `JoinHandle` has finished.
70    ///
71    /// Please note that this method can return `false` even if [`abort`] has been
72    /// called on the task. This is because the cancellation process may take
73    /// some time, and this method does not return `true` until it has
74    /// completed.
75    ///
76    /// [`abort`]: method@JoinHandle::abort
77    pub fn is_finished(&self) -> bool {
78        self.state.lock().unwrap().is_finished()
79    }
80
81    fn set_result(&self, value: T) {
82        self.state.lock().unwrap().set_result(Ok(value));
83    }
84
85    fn clone(&self) -> Self {
86        JoinHandle {
87            state: self.state.clone(),
88        }
89    }
90}
91
92#[derive(Debug)]
93struct State<T> {
94    result: Option<Result<T, JoinError>>,
95    waker: Option<Waker>,
96}
97
98impl<T> State<T> {
99    fn new() -> Self {
100        State {
101            result: None,
102            waker: None,
103        }
104    }
105
106    fn is_finished(&self) -> bool {
107        self.result.is_some()
108    }
109
110    fn set_result(&mut self, value: Result<T, JoinError>) {
111        if self.result.is_none() {
112            self.result = Some(value);
113            self.wake();
114        }
115    }
116
117    fn wake(&mut self) {
118        if let Some(waker) = self.waker.take() {
119            waker.wake();
120        }
121    }
122
123    fn update_waker(&mut self, waker: &Waker) {
124        if let Some(current_waker) = &self.waker {
125            if !waker.will_wake(current_waker) {
126                self.waker = Some(waker.clone());
127            }
128        } else {
129            self.waker = Some(waker.clone())
130        }
131    }
132}
133
134impl<T> Future for JoinHandle<T> {
135    type Output = Result<T, JoinError>;
136
137    fn poll(
138        self: std::pin::Pin<&mut Self>,
139        cx: &mut std::task::Context<'_>,
140    ) -> std::task::Poll<Self::Output> {
141        let mut state = self.state.lock().unwrap();
142        if let Some(value) = state.result.take() {
143            Poll::Ready(value)
144        } else {
145            state.update_waker(cx.waker());
146            Poll::Pending
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use wasm_bindgen_test::wasm_bindgen_test;

    use crate::{sleep, spawn};

    /// Spawned tasks run in the background and their handles yield the
    /// values the tasks produced.
    #[wasm_bindgen_test]
    async fn test_spawn() {
        let first = spawn(async { 1 });
        let second = spawn(async { 2 });

        // Give the executor ample time to drive both tasks to completion.
        sleep(Duration::from_secs(1)).await;

        assert!(first.is_finished());
        assert!(second.is_finished());

        let first_value = first.await.unwrap();
        assert_eq!(first_value, 1);
        let second_value = second.await.unwrap();
        assert_eq!(second_value, 2);
    }

    /// Aborting a still-running task makes its handle resolve to a
    /// cancellation error.
    #[wasm_bindgen_test]
    async fn test_abort() {
        let long_running = spawn(async {
            sleep(Duration::from_secs(10)).await;
            1
        });
        long_running.abort();

        let error = long_running.await.unwrap_err();
        assert!(error.is_cancelled());
    }
}