1use std::future::Future;
9
10use tokio::sync::mpsc;
11
12#[cfg(not(web))]
13mod implementation {
14    use super::*;
15
16    pub trait Post: Send + Sync {}
20
21    impl<T: Send + Sync> Post for T {}
22
    /// The input type of a task that takes no input.
    ///
    /// An alias for the uninhabited [`std::convert::Infallible`], so no value of
    /// this type can ever be sent. It is the default `Input` of [`Blocking`].
    pub type NoInput = std::convert::Infallible;

    /// The handle returned by [`spawn`]; on this target, a Tokio join handle.
    pub type NonBlockingFuture<R> = tokio::task::JoinHandle<R>;
    /// A Tokio join handle yielding an `R`.
    // NOTE(review): not referenced in this part of the file; presumably the
    // handle type for blocking tasks — confirm against callers.
    pub type BlockingFuture<R> = tokio::task::JoinHandle<R>;
    /// The stream of inputs handed to the `work` closure of [`Blocking::spawn`].
    pub type InputReceiver<T> = tokio_stream::wrappers::UnboundedReceiverStream<T>;
    /// The error returned by [`Blocking::send`]; re-exported from Tokio's `mpsc`
    /// channel.
    pub use mpsc::error::SendError;
34
35    pub fn spawn<F: Future<Output: Send> + Send + 'static>(
37        future: F,
38    ) -> NonBlockingFuture<F::Output> {
39        tokio::task::spawn(future)
40    }
41
    /// A task running on a blocking thread that can be sent `Input` messages and
    /// eventually produces an `Output`.
    pub struct Blocking<Input = NoInput, Output = ()> {
        /// Sending half of the unbounded channel that feeds the task's input
        /// stream.
        sender: mpsc::UnboundedSender<Input>,
        /// Handle used to wait for the task's output.
        join_handle: tokio::task::JoinHandle<Output>,
    }
47
48    impl<Input: Send + 'static, Output: Send + 'static> Blocking<Input, Output> {
49        pub async fn spawn<F: Future<Output = Output>>(
51            work: impl FnOnce(InputReceiver<Input>) -> F + Send + 'static,
52        ) -> Self {
53            let (sender, receiver) = mpsc::unbounded_channel();
54            Self {
55                sender,
56                join_handle: tokio::task::spawn_blocking(|| {
57                    futures::executor::block_on(work(receiver.into()))
58                }),
59            }
60        }
61
62        pub async fn join(self) -> Output {
64            self.join_handle.await.expect("task shouldn't be cancelled")
65        }
66
67        pub fn send(&self, message: Input) -> Result<(), SendError<Input>> {
69            self.sender.send(message)
70        }
71    }
72}
73
74#[cfg(web)]
75mod implementation {
76    use std::convert::TryFrom;
77
78    use futures::{channel::oneshot, stream, StreamExt as _};
79    use wasm_bindgen::prelude::*;
80    use web_sys::js_sys;
81
82    use super::*;
83    use crate::dyn_convert;
84
85    pub trait Post: dyn_convert::DynInto<JsValue> {}
90
91    impl<T: dyn_convert::DynInto<JsValue>> Post for T {}
92
93    pub enum NoInput {}
95
96    impl TryFrom<JsValue> for NoInput {
97        type Error = JsValue;
98        fn try_from(value: JsValue) -> Result<Self, JsValue> {
99            Err(value)
100        }
101    }
102
103    impl From<NoInput> for JsValue {
104        fn from(no_input: NoInput) -> Self {
105            match no_input {}
106        }
107    }
108
109    pub struct SendError<T>(T);
111
112    impl<T> std::fmt::Debug for SendError<T> {
113        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
114            f.debug_struct("SendError").finish_non_exhaustive()
115        }
116    }
117
118    impl<T> std::fmt::Display for SendError<T> {
119        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
120            write!(f, "send error")
121        }
122    }
123
124    impl<T> std::error::Error for SendError<T> {}
125
    /// A task running in a Web Worker that can be sent `Input` messages and
    /// eventually produces an `Output`.
    pub struct Blocking<Input = NoInput, Output = ()> {
        /// Handle to the worker thread, used both to post messages to it and to
        /// wait for its output.
        join_handle: wasm_thread::JoinHandle<Output>,
        /// Records the `Input` type parameter without storing a value of it.
        // NOTE(review): `fn(Input)` rather than `Input` presumably keeps the
        // struct's auto traits independent of `Input` — confirm intent.
        _phantom: std::marker::PhantomData<fn(Input)>,
    }
131
    /// The stream of inputs handed to the `work` closure of [`Blocking::spawn`]:
    /// a stream of raw `JsValue` messages, each mapped to a `T` by a plain
    /// function pointer.
    pub type InputReceiver<T> =
        stream::Map<tokio_stream::wrappers::UnboundedReceiverStream<JsValue>, fn(JsValue) -> T>;
135
136    fn convert_or_panic<V, T: TryFrom<V, Error: std::fmt::Debug>>(value: V) -> T {
137        T::try_from(value).expect("type correctness should ensure this can be deserialized")
138    }
139
    /// The handle returned by [`spawn`]: the receiving half of a oneshot channel
    /// on which the spawned future's output is delivered.
    pub type NonBlockingFuture<R> = oneshot::Receiver<R>;
142
143    pub fn spawn<F: Future + 'static>(future: F) -> NonBlockingFuture<F::Output> {
145        let (send, recv) = oneshot::channel();
146        wasm_bindgen_futures::spawn_local(async {
147            let _ = send.send(future.await);
148        });
149        recv
150    }
151
152    impl<Input, Output> Blocking<Input, Output> {
153        pub async fn spawn<F: Future<Output = Output>>(
155            work: impl FnOnce(InputReceiver<Input>) -> F + Send + 'static,
156        ) -> Self
157        where
158            Input: Into<JsValue> + TryFrom<JsValue, Error: std::fmt::Debug>,
159            Output: Send + 'static,
160        {
161            let (ready_sender, ready_receiver) = oneshot::channel();
162            let join_handle = wasm_thread::Builder::new()
163                .spawn(|| async move {
164                    let (input_sender, input_receiver) = mpsc::unbounded_channel::<JsValue>();
165                    let input_receiver =
166                        tokio_stream::wrappers::UnboundedReceiverStream::new(input_receiver);
167                    let onmessage = wasm_bindgen::closure::Closure::<
168                        dyn FnMut(web_sys::MessageEvent) -> Result<(), JsError>,
169                    >::new(
170                        move |event: web_sys::MessageEvent| -> Result<(), JsError> {
171                            input_sender.send(event.data())?;
172                            Ok(())
173                        },
174                    );
175                    js_sys::global()
176                        .dyn_into::<web_sys::DedicatedWorkerGlobalScope>()
177                        .unwrap()
178                        .set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
179                    onmessage.forget(); ready_sender.send(()).unwrap();
181                    work(input_receiver.map(convert_or_panic::<JsValue, Input>)).await
182                })
183                .expect("should successfully start Web Worker");
184
185            ready_receiver
186                .await
187                .expect("should successfully initialize the worker thread");
188            Self {
189                join_handle,
190                _phantom: Default::default(),
191            }
192        }
193
194        pub fn send(&self, message: Input) -> Result<(), SendError<Input>>
197        where
198            Input: Into<JsValue> + TryFrom<JsValue> + Clone,
199        {
200            self.join_handle
201                .thread()
202                .post_message(&message.clone().into())
203                .map_err(|_| SendError(message))
204        }
205
206        pub async fn join(self) -> Output {
208            match self.join_handle.join_async().await {
209                Ok(output) => output,
210                Err(panic) => std::panic::resume_unwind(panic),
211            }
212        }
213    }
214}
215
216pub use implementation::*;