1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
/// Error produced when awaiting a [`JoinHandle`] whose task never delivered a value.
///
/// This is returned when the sending half of the task's result channel is dropped before
/// a value is sent. For example, this happens if the spawned future was dropped before it
/// completed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct JoinError;

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("JoinError")
    }
}

impl std::error::Error for JoinError {}
15
16pub struct JoinHandle<T>
17where
18    T: Send,
19{
20    handle: tokio::sync::oneshot::Receiver<T>,
21}
22
/// Detaches a `()`-returning future on `wasm32` targets.
///
/// The future is handed to `wasm_bindgen_futures::spawn_local`.
#[cfg(target_arch = "wasm32")]
macro_rules! spawn_impl {
    ($task:expr) => {
        wasm_bindgen_futures::spawn_local($task)
    };
}

/// Detaches a `()`-returning future on every non-`wasm32` target.
///
/// The future is handed to `tokio::spawn`. That call must happen inside a tokio runtime.
/// The resulting `tokio::task::JoinHandle` is returned to the call site.
#[cfg(not(target_arch = "wasm32"))]
macro_rules! spawn_impl {
    ($task:expr) => {
        tokio::spawn($task)
    };
}
36
37pub fn spawn<F, T>(future: F) -> JoinHandle<T>
38where
39    F: Future<Output = T> + 'static + Send,
40    T: Send + 'static,
41{
42    let (sender, receiver) = tokio::sync::oneshot::channel();
43    spawn_impl!(async {
44        let result = future.await;
45        let _ = sender.send(result);
46    });
47    JoinHandle { handle: receiver }
48}
49
50impl<T> Future for JoinHandle<T>
51where
52    T: Send,
53{
54    type Output = Result<T, JoinError>;
55
56    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
57        match Pin::new(&mut self.handle).poll(cx) {
58            Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)),
59            Poll::Ready(Err(_)) => Poll::Ready(Err(JoinError)),
60            Poll::Pending => Poll::Pending,
61        }
62    }
63}