1use futures::Future;
4use std::rc::Rc;
5use std::sync::Mutex;
6use std::task::{Poll, Waker};
7
8pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
10where
11 F: Future + 'static,
12 F::Output: 'static,
13{
14 let join_handle = JoinHandle::new();
15 let join_handle_clone = join_handle.clone();
16 wasm_bindgen_futures::spawn_local(async move {
17 join_handle_clone.set_result(future.await);
18 });
19 join_handle
20}
21
/// Error produced when awaiting a [`JoinHandle`] whose task did not deliver
/// its output.
///
/// In this module a `JoinError` is only created by `JoinHandle::abort`, so
/// every value currently represents a cancelled task.
#[derive(Debug)]
#[non_exhaustive]
pub struct JoinError {}

impl JoinError {
    /// Returns `true` if the task was cancelled.
    ///
    /// Always `true` for now: cancellation is the only way a `JoinError`
    /// is constructed.
    pub fn is_cancelled(&self) -> bool {
        true
    }
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task was cancelled")
    }
}

// Lets `JoinError` flow through `?` into `Box<dyn Error>` and error-reporting
// libraries like any other standard error type.
impl std::error::Error for JoinError {}
35
36#[derive(Debug)]
47pub struct JoinHandle<T> {
48 state: Rc<Mutex<State<T>>>,
49}
50
51impl<T> JoinHandle<T> {
52 fn new() -> Self {
53 JoinHandle {
54 state: Rc::new(Mutex::new(State::new())),
55 }
56 }
57
58 pub fn abort(&self) {
66 self.state.lock().unwrap().set_result(Err(JoinError {}));
67 }
68
69 pub fn is_finished(&self) -> bool {
78 self.state.lock().unwrap().is_finished()
79 }
80
81 fn set_result(&self, value: T) {
82 self.state.lock().unwrap().set_result(Ok(value));
83 }
84
85 fn clone(&self) -> Self {
86 JoinHandle {
87 state: self.state.clone(),
88 }
89 }
90}
91
92#[derive(Debug)]
93struct State<T> {
94 result: Option<Result<T, JoinError>>,
95 waker: Option<Waker>,
96}
97
98impl<T> State<T> {
99 fn new() -> Self {
100 State {
101 result: None,
102 waker: None,
103 }
104 }
105
106 fn is_finished(&self) -> bool {
107 self.result.is_some()
108 }
109
110 fn set_result(&mut self, value: Result<T, JoinError>) {
111 if self.result.is_none() {
112 self.result = Some(value);
113 self.wake();
114 }
115 }
116
117 fn wake(&mut self) {
118 if let Some(waker) = self.waker.take() {
119 waker.wake();
120 }
121 }
122
123 fn update_waker(&mut self, waker: &Waker) {
124 if let Some(current_waker) = &self.waker {
125 if !waker.will_wake(current_waker) {
126 self.waker = Some(waker.clone());
127 }
128 } else {
129 self.waker = Some(waker.clone())
130 }
131 }
132}
133
134impl<T> Future for JoinHandle<T> {
135 type Output = Result<T, JoinError>;
136
137 fn poll(
138 self: std::pin::Pin<&mut Self>,
139 cx: &mut std::task::Context<'_>,
140 ) -> std::task::Poll<Self::Output> {
141 let mut state = self.state.lock().unwrap();
142 if let Some(value) = state.result.take() {
143 Poll::Ready(value)
144 } else {
145 state.update_waker(cx.waker());
146 Poll::Pending
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use wasm_bindgen_test::wasm_bindgen_test;

    use crate::{sleep, spawn};

    /// Spawned tasks run to completion and their outputs are delivered
    /// through the returned handles.
    #[wasm_bindgen_test]
    async fn test_spawn() {
        let task_1 = spawn(async { 1 });
        let task_2 = spawn(async { 2 });

        // Give the executor a chance to run both tasks. Their futures are
        // immediately ready, so a short timer should suffice; there is no
        // need to wait a full second. Increase this if it ever proves flaky.
        sleep(Duration::from_millis(10)).await;

        assert!(task_1.is_finished());
        assert!(task_2.is_finished());

        assert_eq!(task_1.await.unwrap(), 1);
        assert_eq!(task_2.await.unwrap(), 2);
    }

    /// Aborting a pending task marks its handle finished immediately and
    /// makes it resolve to a cancellation error.
    #[wasm_bindgen_test]
    async fn test_abort() {
        let task = spawn(async {
            sleep(Duration::from_secs(10)).await;
            1
        });
        task.abort();

        assert!(task.is_finished());
        assert!(task.await.unwrap_err().is_cancelled());
    }
}