mezzenger_webworker/
lib.rs

//! Transport for communication with
//! [Web Workers](https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API).
//!
//! See [repository](https://github.com/zduny/mezzenger) for more info.

6use std::{
7    cell::RefCell,
8    collections::VecDeque,
9    fmt::{Debug, Display},
10    marker::PhantomData,
11    pin::Pin,
12    rc::Rc,
13    task::{Context, Poll, Waker},
14};
15
16use futures::{stream::FusedStream, Sink, Stream};
17use js_sys::Uint8Array;
18use js_utils::{
19    event::{EventListener, When},
20    JsError, Queue,
21};
22use kodec::{Decode, Encode};
23use serde::{Deserialize, Serialize};
24use wasm_bindgen::{JsCast, JsValue};
25use web_sys::{DedicatedWorkerGlobalScope, Event, EventTarget, MessageEvent, Worker};
26
27pub trait PostMessage {
28    fn post_message(&self, message: &JsValue) -> Result<(), JsValue>;
29}
30
31impl PostMessage for Worker {
32    fn post_message(&self, message: &JsValue) -> Result<(), JsValue> {
33        self.post_message(message)
34    }
35}
36
37impl PostMessage for DedicatedWorkerGlobalScope {
38    fn post_message(&self, message: &JsValue) -> Result<(), JsValue> {
39        self.post_message(message)
40    }
41}
42
43#[derive(Debug)]
44pub enum Error<SerializationError, DeserializationError> {
45    SendingError(JsError),
46    SerializationError(SerializationError),
47    DeserializationError(DeserializationError),
48    WorkerError(Event),
49    MessageError(MessageEvent),
50}
51
52impl<SerializationError, DeserializationError> Display
53    for Error<SerializationError, DeserializationError>
54where
55    SerializationError: Display,
56    DeserializationError: Display,
57{
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            Error::SendingError(error) => write!(f, "failed to send message: {error}"),
61            Error::SerializationError(error) => write!(f, "failed to serialize message: {error}"),
62            Error::DeserializationError(error) => {
63                write!(f, "failed to deserialize message: {error}")
64            }
65            Error::WorkerError(error) => write!(f, "error occurred in worker: {error:?}"),
66            Error::MessageError(error) => write!(f, "message error occurred: {error:?}"),
67        }
68    }
69}
70
71impl<SerializationError, DeserializationError> std::error::Error
72    for Error<SerializationError, DeserializationError>
73where
74    SerializationError: Debug + Display,
75    DeserializationError: Debug + Display,
76{
77}
78
79#[derive(Debug, Serialize, Deserialize)]
80enum Wrapper<Message> {
81    Open,
82    Message(Message),
83    Close,
84}
85
/// Receive-side state shared between a transport and its event-listener closures.
struct State<Incoming, Error> {
    /// Received items (or receive errors) not yet yielded, oldest first.
    incoming: VecDeque<Result<Incoming, Error>>,
    /// Waker of the task that last got `Pending` from the stream, if any.
    waker: Option<Waker>,
    /// Set once the transport has been closed, locally or by the peer.
    closed: bool,
}
91
92impl<Incoming, Error> State<Incoming, Error> {
93    fn new() -> Self {
94        State {
95            incoming: VecDeque::new(),
96            waker: None,
97            closed: false,
98        }
99    }
100
101    fn message(&mut self, message: Incoming) {
102        self.incoming.push_back(Ok(message));
103        self.wake();
104    }
105
106    fn error(&mut self, error: Error) {
107        self.incoming.push_back(Err(error));
108        self.wake();
109    }
110
111    fn close(&mut self) {
112        self.closed = true;
113        self.wake();
114    }
115
116    fn update_waker_with(&mut self, other: &Waker) {
117        if let Some(waker) = &self.waker {
118            if !waker.will_wake(other) {
119                self.waker = Some(other.clone());
120            }
121        } else {
122            self.waker = Some(other.clone());
123        }
124    }
125
126    fn wake(&mut self) {
127        if let Some(waker) = self.waker.take() {
128            waker.wake();
129        }
130    }
131}
132
133impl<Incoming, Error> Drop for State<Incoming, Error> {
134    fn drop(&mut self) {
135        if !self.closed {
136            self.close();
137        }
138    }
139}
140
141/// Transport for communication with
142/// [Web Workers](https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API).
143pub struct Transport<T, Codec, Incoming, Outgoing>
144where
145    T: AsRef<EventTarget> + PostMessage,
146    Codec: kodec::Codec,
147{
148    target: Rc<T>,
149    codec: Codec,
150    #[allow(clippy::type_complexity)]
151    state: Rc<RefCell<State<Incoming, Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>>>,
152    buffer: RefCell<Vec<u8>>,
153    _message_listener: EventListener<T, MessageEvent>,
154    _error_listener: EventListener<T, Event>,
155    _message_error_listener: EventListener<T, MessageEvent>,
156    _outgoing: PhantomData<Outgoing>,
157}
158
159impl<T, Codec, Incoming, Outgoing> Transport<T, Codec, Incoming, Outgoing>
160where
161    T: AsRef<EventTarget> + PostMessage,
162    Codec: 'static + kodec::Codec + Clone,
163    Incoming: 'static,
164    Outgoing: 'static + Serialize,
165    <Codec as Encode>::Error: 'static,
166    <Codec as Decode>::Error: 'static,
167    for<'de> Incoming: serde::de::Deserialize<'de>,
168{
169    async fn new_inner(target: &Rc<T>, codec: Codec, is_worker: bool) -> Result<Self, JsError> {
170        let open_notifier = Rc::new(Queue::new());
171
172        let target = target.clone();
173        let codec_clone = codec.clone();
174        let state = Rc::new(RefCell::new(State::new()));
175        let state_clone = state.clone();
176        let open_notifier_clone = Rc::downgrade(&open_notifier);
177        let message_listener = target.when("message", move |event: MessageEvent| {
178            let array = Uint8Array::new(&event.data());
179            let vector = array.to_vec();
180            let result: Result<Wrapper<Incoming>, _> = codec_clone.decode(&vector[..]);
181            match result {
182                Ok(message) => match message {
183                    Wrapper::Open => {
184                        if let Some(notifier) = open_notifier_clone.upgrade() {
185                            notifier.push(());
186                        } else {
187                            unreachable!("open message received twice!")
188                        }
189                    }
190                    Wrapper::Message(message) => state_clone.borrow_mut().message(message),
191                    Wrapper::Close => state_clone.borrow_mut().close(),
192                },
193                Err(error) => state_clone
194                    .borrow_mut()
195                    .error(Error::DeserializationError(error)),
196            }
197        })?;
198        let state_clone = state.clone();
199        let error_listener = target.when("error", move |event: Event| {
200            state_clone.borrow_mut().error(Error::WorkerError(event));
201        })?;
202        let state_clone = state.clone();
203        let message_error_listener = target.when("messageerror", move |event: MessageEvent| {
204            state_clone.borrow_mut().error(Error::MessageError(event));
205        })?;
206        let buffer = RefCell::new(vec![]);
207        let transport = Transport {
208            target,
209            codec,
210            state,
211            buffer,
212            _message_listener: message_listener,
213            _error_listener: error_listener,
214            _message_error_listener: message_error_listener,
215            _outgoing: PhantomData,
216        };
217
218        if is_worker {
219            let _ = transport.send_inner(Wrapper::Open);
220            open_notifier.pop().await;
221        } else {
222            open_notifier.pop().await;
223            let _ = transport.send_inner(Wrapper::Open);
224        }
225
226        Ok(transport)
227    }
228
229    fn send_inner(
230        &self,
231        message: Wrapper<Outgoing>,
232    ) -> Result<(), Error<<Codec as Encode>::Error, <Codec as Decode>::Error>> {
233        let mut buffer = self.buffer.borrow_mut();
234        self.codec
235            .encode(&mut *buffer, &message)
236            .map_err(Error::SerializationError)?;
237        let js_array = Uint8Array::from(&buffer[..]);
238        self.target
239            .post_message(&js_array)
240            .map_err(|error| Error::SendingError(error.into()))?;
241        buffer.clear();
242        Ok(())
243    }
244}
245
246impl<Codec, Incoming, Outgoing> Transport<Worker, Codec, Incoming, Outgoing>
247where
248    Codec: 'static + kodec::Codec + Clone,
249    Incoming: 'static,
250    Outgoing: 'static + Serialize,
251    <Codec as Encode>::Error: 'static,
252    <Codec as Decode>::Error: 'static,
253    for<'de> Incoming: serde::de::Deserialize<'de>,
254{
255    /// Create new transport for communication with worker.
256    pub async fn new(worker: &Rc<Worker>, codec: Codec) -> Result<Self, JsError> {
257        Transport::new_inner(worker, codec, false).await
258    }
259}
260
261impl<Codec, Incoming, Outgoing> Transport<DedicatedWorkerGlobalScope, Codec, Incoming, Outgoing>
262where
263    Codec: 'static + kodec::Codec + Clone,
264    Incoming: 'static,
265    Outgoing: 'static + Serialize,
266    <Codec as Encode>::Error: 'static,
267    <Codec as Decode>::Error: 'static,
268    for<'de> Incoming: serde::de::Deserialize<'de>,
269{
270    /// Create new transport inside worker.
271    ///
272    /// Will panic if called outside worker scope.
273    pub async fn new_in_worker(codec: Codec) -> Result<Self, JsError> {
274        let global = Rc::new(
275            js_sys::global()
276                .dyn_into::<DedicatedWorkerGlobalScope>()
277                .unwrap(),
278        );
279        Transport::new_inner(&global, codec, true).await
280    }
281}
282
283impl<T, Codec, Incoming, Outgoing> Sink<Outgoing> for Transport<T, Codec, Incoming, Outgoing>
284where
285    T: AsRef<EventTarget> + PostMessage,
286    Codec: 'static + kodec::Codec + Clone,
287    Incoming: 'static,
288    Outgoing: 'static + Serialize,
289    <Codec as Encode>::Error: 'static,
290    <Codec as Decode>::Error: 'static,
291    for<'de> Incoming: serde::de::Deserialize<'de>,
292{
293    type Error = mezzenger::Error<Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
294
295    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
296        if self.state.borrow().closed {
297            Poll::Ready(Err(mezzenger::Error::Closed))
298        } else {
299            Poll::Ready(Ok(()))
300        }
301    }
302
303    fn start_send(self: Pin<&mut Self>, item: Outgoing) -> Result<(), Self::Error> {
304        if self.state.borrow().closed {
305            Err(mezzenger::Error::Closed)
306        } else {
307            self.send_inner(Wrapper::Message(item))
308                .map_err(mezzenger::Error::Other)
309        }
310    }
311
312    fn poll_flush(
313        self: Pin<&mut Self>,
314        _cx: &mut Context<'_>,
315    ) -> std::task::Poll<Result<(), Self::Error>> {
316        if self.state.borrow().closed {
317            Poll::Ready(Err(mezzenger::Error::Closed))
318        } else {
319            Poll::Ready(Ok(()))
320        }
321    }
322
323    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
324        if self.state.borrow().closed {
325            Poll::Ready(Err(mezzenger::Error::Closed))
326        } else {
327            let _ = self.send_inner(Wrapper::Close);
328            self.state.borrow_mut().close();
329            Poll::Ready(Ok(()))
330        }
331    }
332}
333
334impl<T, Codec, Incoming, Outgoing> Stream for Transport<T, Codec, Incoming, Outgoing>
335where
336    T: AsRef<EventTarget> + PostMessage,
337    Codec: kodec::Codec,
338{
339    type Item = Result<Incoming, Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
340
341    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
342        let mut state = self.state.borrow_mut();
343        if state.closed && state.incoming.is_empty() {
344            Poll::Ready(None)
345        } else if let Some(item) = state.incoming.pop_front() {
346            Poll::Ready(Some(item))
347        } else {
348            state.update_waker_with(cx.waker());
349            Poll::Pending
350        }
351    }
352}
353
354impl<T, Codec, Incoming, Outgoing> FusedStream for Transport<T, Codec, Incoming, Outgoing>
355where
356    T: AsRef<EventTarget> + PostMessage,
357    Codec: kodec::Codec,
358{
359    fn is_terminated(&self) -> bool {
360        let state = self.state.borrow();
361        state.closed && state.incoming.is_empty()
362    }
363}
364
365impl<T, Codec, Incoming, Outgoing> mezzenger::Reliable for Transport<T, Codec, Incoming, Outgoing>
366where
367    T: AsRef<EventTarget> + PostMessage,
368    Codec: kodec::Codec,
369{
370}
371
372impl<T, Codec, Incoming, Outgoing> mezzenger::Order for Transport<T, Codec, Incoming, Outgoing>
373where
374    T: AsRef<EventTarget> + PostMessage,
375    Codec: kodec::Codec,
376{
377}