1use futures::channel::oneshot::{self, Sender};
4use futures::{select, Future, FutureExt};
5use std::rc::Rc;
6use std::sync::Mutex;
7use std::task::{Poll, Waker};
8
9pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
11where
12 F: Future + 'static,
13 F::Output: 'static,
14{
15 let (cancel_sender, mut cancel_receiver) = oneshot::channel();
16 let join_handle = JoinHandle::new(cancel_sender);
17 let join_handle_clone = join_handle.clone();
18 wasm_bindgen_futures::spawn_local(async move {
19 select! {
20 result = future.fuse() => join_handle_clone.set_result(result),
21 _ = cancel_receiver => ()
22 }
23 });
24 join_handle
25}
26
/// Error returned when awaiting a [`JoinHandle`] whose task did not produce a value.
///
/// In this module a `JoinError` is only created by [`JoinHandle::abort`], so every
/// instance represents a cancelled task.
#[derive(Debug)]
#[non_exhaustive]
pub struct JoinError {}

impl JoinError {
    /// Returns `true` if the task was cancelled.
    ///
    /// Cancellation via [`JoinHandle::abort`] is currently the only way a
    /// `JoinError` is produced, so this always returns `true`.
    pub fn is_cancelled(&self) -> bool {
        true
    }
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task was cancelled")
    }
}

// Lets `JoinError` flow into `Box<dyn Error>` and other generic error plumbing.
impl std::error::Error for JoinError {}
40
41#[derive(Debug)]
52pub struct JoinHandle<T> {
53 state: Rc<Mutex<State<T>>>,
54 cancel_sender: Mutex<Option<Sender<()>>>,
55}
56
57impl<T> JoinHandle<T> {
58 fn new(cancel_sender: Sender<()>) -> Self {
59 JoinHandle {
60 state: Rc::new(Mutex::new(State::new())),
61 cancel_sender: Mutex::new(Some(cancel_sender)),
62 }
63 }
64
65 pub fn abort(&self) {
73 self.state.lock().unwrap().set_result(Err(JoinError {}));
74 if let Some(sender) = self.cancel_sender.lock().unwrap().take() {
75 let _ = sender.send(());
76 }
77 }
78
79 pub fn is_finished(&self) -> bool {
88 self.state.lock().unwrap().is_finished()
89 }
90
91 fn set_result(&self, value: T) {
92 self.state.lock().unwrap().set_result(Ok(value));
93 }
94
95 fn clone(&self) -> Self {
96 JoinHandle {
97 state: self.state.clone(),
98 cancel_sender: Mutex::new(None),
99 }
100 }
101}
102
103#[derive(Debug)]
104struct State<T> {
105 result: Option<Result<T, JoinError>>,
106 waker: Option<Waker>,
107}
108
109impl<T> State<T> {
110 fn new() -> Self {
111 State {
112 result: None,
113 waker: None,
114 }
115 }
116
117 fn is_finished(&self) -> bool {
118 self.result.is_some()
119 }
120
121 fn set_result(&mut self, value: Result<T, JoinError>) {
122 if self.result.is_none() {
123 self.result = Some(value);
124 self.wake();
125 }
126 }
127
128 fn wake(&mut self) {
129 if let Some(waker) = self.waker.take() {
130 waker.wake();
131 }
132 }
133
134 fn update_waker(&mut self, waker: &Waker) {
135 if let Some(current_waker) = &self.waker {
136 if !waker.will_wake(current_waker) {
137 self.waker = Some(waker.clone());
138 }
139 } else {
140 self.waker = Some(waker.clone())
141 }
142 }
143}
144
145impl<T> Future for JoinHandle<T> {
146 type Output = Result<T, JoinError>;
147
148 fn poll(
149 self: std::pin::Pin<&mut Self>,
150 cx: &mut std::task::Context<'_>,
151 ) -> std::task::Poll<Self::Output> {
152 let mut state = self.state.lock().unwrap();
153 if let Some(value) = state.result.take() {
154 Poll::Ready(value)
155 } else {
156 state.update_waker(cx.waker());
157 Poll::Pending
158 }
159 }
160}
161
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use wasm_bindgen_test::wasm_bindgen_test;

    use crate::{sleep, spawn};

    /// Spawned tasks run on their own, and their outputs can be awaited afterwards.
    #[wasm_bindgen_test]
    async fn test_spawn() {
        let first = spawn(async { 1 });
        let second = spawn(async { 2 });

        // Give the executor a chance to drive both tasks to completion.
        sleep(Duration::from_secs(1)).await;

        for handle in [&first, &second] {
            assert!(handle.is_finished());
        }

        assert_eq!(first.await.unwrap(), 1);
        assert_eq!(second.await.unwrap(), 2);
    }

    /// Aborting a still-running task makes its handle resolve to a cancellation error.
    #[wasm_bindgen_test]
    async fn test_abort() {
        let handle = spawn(async {
            sleep(Duration::from_secs(10)).await;
            1
        });
        handle.abort();

        let err = handle.await.unwrap_err();
        assert!(err.is_cancelled());
    }
}