1use std::{
7 cell::RefCell,
8 collections::VecDeque,
9 fmt::{Debug, Display},
10 marker::PhantomData,
11 pin::Pin,
12 rc::Rc,
13 task::{Context, Poll, Waker},
14};
15
16use futures::{stream::FusedStream, Sink, Stream};
17use js_sys::Uint8Array;
18use js_utils::{
19 event::{EventListener, When},
20 JsError, Queue,
21};
22use kodec::{Decode, Encode};
23use serde::{Deserialize, Serialize};
24use wasm_bindgen::{JsCast, JsValue};
25use web_sys::{DedicatedWorkerGlobalScope, Event, EventTarget, MessageEvent, Worker};
26
/// Something a [`Transport`] can post encoded messages to.
///
/// Implemented in this file for [`Worker`] (used by `Transport::new`) and for
/// [`DedicatedWorkerGlobalScope`] (used by `Transport::new_in_worker`). Both impls forward to
/// the type's own `post_message`.
pub trait PostMessage {
    /// Posts `message` to the other side.
    ///
    /// # Errors
    ///
    /// Returns the `JsValue` reported by the underlying implementation when posting fails.
    fn post_message(&self, message: &JsValue) -> Result<(), JsValue>;
}
30
31impl PostMessage for Worker {
32 fn post_message(&self, message: &JsValue) -> Result<(), JsValue> {
33 self.post_message(message)
34 }
35}
36
37impl PostMessage for DedicatedWorkerGlobalScope {
38 fn post_message(&self, message: &JsValue) -> Result<(), JsValue> {
39 self.post_message(message)
40 }
41}
42
/// Errors produced by [`Transport`].
///
/// Generic over the codec's encoding (`SerializationError`) and decoding
/// (`DeserializationError`) error types.
#[derive(Debug)]
pub enum Error<SerializationError, DeserializationError> {
    /// Posting an already-encoded message to the target failed.
    SendingError(JsError),
    /// The codec failed to encode an outgoing message.
    SerializationError(SerializationError),
    /// The codec failed to decode an incoming message.
    DeserializationError(DeserializationError),
    /// An `error` event was dispatched on the target.
    WorkerError(Event),
    /// A `messageerror` event was dispatched on the target.
    MessageError(MessageEvent),
}
51
52impl<SerializationError, DeserializationError> Display
53 for Error<SerializationError, DeserializationError>
54where
55 SerializationError: Display,
56 DeserializationError: Display,
57{
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 Error::SendingError(error) => write!(f, "failed to send message: {error}"),
61 Error::SerializationError(error) => write!(f, "failed to serialize message: {error}"),
62 Error::DeserializationError(error) => {
63 write!(f, "failed to deserialize message: {error}")
64 }
65 Error::WorkerError(error) => write!(f, "error occurred in worker: {error:?}"),
66 Error::MessageError(error) => write!(f, "message error occurred: {error:?}"),
67 }
68 }
69}
70
71impl<SerializationError, DeserializationError> std::error::Error
72 for Error<SerializationError, DeserializationError>
73where
74 SerializationError: Debug + Display,
75 DeserializationError: Debug + Display,
76{
77}
78
/// Envelope wrapped around everything sent through a [`Transport`].
///
/// `Open` is sent once by each side during the handshake in `new_inner`. `Message` carries a
/// user payload. `Close` is sent by `poll_close`, and on receipt the other side marks its
/// state closed.
// NOTE(review): this enum is the serialized form exchanged with the peer, so renaming or
// reordering variants presumably breaks compatibility with peers built from older code
// (exact impact depends on the codec) — confirm before changing.
#[derive(Debug, Serialize, Deserialize)]
enum Wrapper<Message> {
    /// Handshake announcement.
    Open,
    /// A user payload.
    Message(Message),
    /// The sender has closed its end.
    Close,
}
85
/// Receive-side state shared between the event listeners (writers) and the `Stream` impl
/// (reader) of a [`Transport`].
struct State<Incoming, Error> {
    /// Decoded messages and errors not yet returned by `poll_next`, oldest first.
    incoming: VecDeque<Result<Incoming, Error>>,
    /// Waker of the task that last got `Pending` from `poll_next`; taken when woken.
    waker: Option<Waker>,
    /// Set when either the peer sent `Close` or this side closed the sink. The stream ends
    /// once this is set and `incoming` is empty.
    closed: bool,
}
91
92impl<Incoming, Error> State<Incoming, Error> {
93 fn new() -> Self {
94 State {
95 incoming: VecDeque::new(),
96 waker: None,
97 closed: false,
98 }
99 }
100
101 fn message(&mut self, message: Incoming) {
102 self.incoming.push_back(Ok(message));
103 self.wake();
104 }
105
106 fn error(&mut self, error: Error) {
107 self.incoming.push_back(Err(error));
108 self.wake();
109 }
110
111 fn close(&mut self) {
112 self.closed = true;
113 self.wake();
114 }
115
116 fn update_waker_with(&mut self, other: &Waker) {
117 if let Some(waker) = &self.waker {
118 if !waker.will_wake(other) {
119 self.waker = Some(other.clone());
120 }
121 } else {
122 self.waker = Some(other.clone());
123 }
124 }
125
126 fn wake(&mut self) {
127 if let Some(waker) = self.waker.take() {
128 waker.wake();
129 }
130 }
131}
132
133impl<Incoming, Error> Drop for State<Incoming, Error> {
134 fn drop(&mut self) {
135 if !self.closed {
136 self.close();
137 }
138 }
139}
140
/// A [`Sink`] of `Outgoing` and [`Stream`] of `Incoming` messages exchanged with a web
/// worker through a [`PostMessage`] endpoint.
///
/// `T` is the endpoint. `Transport::new` uses a [`Worker`] (owning side) and
/// `Transport::new_in_worker` uses a [`DedicatedWorkerGlobalScope`] (worker side).
pub struct Transport<T, Codec, Incoming, Outgoing>
where
    T: AsRef<EventTarget> + PostMessage,
    Codec: kodec::Codec,
{
    /// Endpoint that outgoing messages are posted to and whose events are listened on.
    target: Rc<T>,
    /// Codec used to encode outgoing messages (a clone decodes incoming ones).
    codec: Codec,
    /// Receive-side state shared with the event listeners.
    // The allow is needed because the full type spells out both codec error types.
    #[allow(clippy::type_complexity)]
    state: Rc<RefCell<State<Incoming, Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>>>,
    /// Scratch buffer that outgoing messages are encoded into.
    buffer: RefCell<Vec<u8>>,
    // The listeners are stored only to keep them alive for the transport's lifetime;
    // presumably dropping a js_utils `EventListener` unregisters it — verify.
    // NOTE: field order is declaration (= drop) order; keep the listeners after `state`.
    _message_listener: EventListener<T, MessageEvent>,
    _error_listener: EventListener<T, Event>,
    _message_error_listener: EventListener<T, MessageEvent>,
    /// Ties the `Outgoing` type parameter to the struct without storing a value.
    _outgoing: PhantomData<Outgoing>,
}
158
159impl<T, Codec, Incoming, Outgoing> Transport<T, Codec, Incoming, Outgoing>
160where
161 T: AsRef<EventTarget> + PostMessage,
162 Codec: 'static + kodec::Codec + Clone,
163 Incoming: 'static,
164 Outgoing: 'static + Serialize,
165 <Codec as Encode>::Error: 'static,
166 <Codec as Decode>::Error: 'static,
167 for<'de> Incoming: serde::de::Deserialize<'de>,
168{
169 async fn new_inner(target: &Rc<T>, codec: Codec, is_worker: bool) -> Result<Self, JsError> {
170 let open_notifier = Rc::new(Queue::new());
171
172 let target = target.clone();
173 let codec_clone = codec.clone();
174 let state = Rc::new(RefCell::new(State::new()));
175 let state_clone = state.clone();
176 let open_notifier_clone = Rc::downgrade(&open_notifier);
177 let message_listener = target.when("message", move |event: MessageEvent| {
178 let array = Uint8Array::new(&event.data());
179 let vector = array.to_vec();
180 let result: Result<Wrapper<Incoming>, _> = codec_clone.decode(&vector[..]);
181 match result {
182 Ok(message) => match message {
183 Wrapper::Open => {
184 if let Some(notifier) = open_notifier_clone.upgrade() {
185 notifier.push(());
186 } else {
187 unreachable!("open message received twice!")
188 }
189 }
190 Wrapper::Message(message) => state_clone.borrow_mut().message(message),
191 Wrapper::Close => state_clone.borrow_mut().close(),
192 },
193 Err(error) => state_clone
194 .borrow_mut()
195 .error(Error::DeserializationError(error)),
196 }
197 })?;
198 let state_clone = state.clone();
199 let error_listener = target.when("error", move |event: Event| {
200 state_clone.borrow_mut().error(Error::WorkerError(event));
201 })?;
202 let state_clone = state.clone();
203 let message_error_listener = target.when("messageerror", move |event: MessageEvent| {
204 state_clone.borrow_mut().error(Error::MessageError(event));
205 })?;
206 let buffer = RefCell::new(vec![]);
207 let transport = Transport {
208 target,
209 codec,
210 state,
211 buffer,
212 _message_listener: message_listener,
213 _error_listener: error_listener,
214 _message_error_listener: message_error_listener,
215 _outgoing: PhantomData,
216 };
217
218 if is_worker {
219 let _ = transport.send_inner(Wrapper::Open);
220 open_notifier.pop().await;
221 } else {
222 open_notifier.pop().await;
223 let _ = transport.send_inner(Wrapper::Open);
224 }
225
226 Ok(transport)
227 }
228
229 fn send_inner(
230 &self,
231 message: Wrapper<Outgoing>,
232 ) -> Result<(), Error<<Codec as Encode>::Error, <Codec as Decode>::Error>> {
233 let mut buffer = self.buffer.borrow_mut();
234 self.codec
235 .encode(&mut *buffer, &message)
236 .map_err(Error::SerializationError)?;
237 let js_array = Uint8Array::from(&buffer[..]);
238 self.target
239 .post_message(&js_array)
240 .map_err(|error| Error::SendingError(error.into()))?;
241 buffer.clear();
242 Ok(())
243 }
244}
245
246impl<Codec, Incoming, Outgoing> Transport<Worker, Codec, Incoming, Outgoing>
247where
248 Codec: 'static + kodec::Codec + Clone,
249 Incoming: 'static,
250 Outgoing: 'static + Serialize,
251 <Codec as Encode>::Error: 'static,
252 <Codec as Decode>::Error: 'static,
253 for<'de> Incoming: serde::de::Deserialize<'de>,
254{
255 pub async fn new(worker: &Rc<Worker>, codec: Codec) -> Result<Self, JsError> {
257 Transport::new_inner(worker, codec, false).await
258 }
259}
260
261impl<Codec, Incoming, Outgoing> Transport<DedicatedWorkerGlobalScope, Codec, Incoming, Outgoing>
262where
263 Codec: 'static + kodec::Codec + Clone,
264 Incoming: 'static,
265 Outgoing: 'static + Serialize,
266 <Codec as Encode>::Error: 'static,
267 <Codec as Decode>::Error: 'static,
268 for<'de> Incoming: serde::de::Deserialize<'de>,
269{
270 pub async fn new_in_worker(codec: Codec) -> Result<Self, JsError> {
274 let global = Rc::new(
275 js_sys::global()
276 .dyn_into::<DedicatedWorkerGlobalScope>()
277 .unwrap(),
278 );
279 Transport::new_inner(&global, codec, true).await
280 }
281}
282
283impl<T, Codec, Incoming, Outgoing> Sink<Outgoing> for Transport<T, Codec, Incoming, Outgoing>
284where
285 T: AsRef<EventTarget> + PostMessage,
286 Codec: 'static + kodec::Codec + Clone,
287 Incoming: 'static,
288 Outgoing: 'static + Serialize,
289 <Codec as Encode>::Error: 'static,
290 <Codec as Decode>::Error: 'static,
291 for<'de> Incoming: serde::de::Deserialize<'de>,
292{
293 type Error = mezzenger::Error<Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
294
295 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
296 if self.state.borrow().closed {
297 Poll::Ready(Err(mezzenger::Error::Closed))
298 } else {
299 Poll::Ready(Ok(()))
300 }
301 }
302
303 fn start_send(self: Pin<&mut Self>, item: Outgoing) -> Result<(), Self::Error> {
304 if self.state.borrow().closed {
305 Err(mezzenger::Error::Closed)
306 } else {
307 self.send_inner(Wrapper::Message(item))
308 .map_err(mezzenger::Error::Other)
309 }
310 }
311
312 fn poll_flush(
313 self: Pin<&mut Self>,
314 _cx: &mut Context<'_>,
315 ) -> std::task::Poll<Result<(), Self::Error>> {
316 if self.state.borrow().closed {
317 Poll::Ready(Err(mezzenger::Error::Closed))
318 } else {
319 Poll::Ready(Ok(()))
320 }
321 }
322
323 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
324 if self.state.borrow().closed {
325 Poll::Ready(Err(mezzenger::Error::Closed))
326 } else {
327 let _ = self.send_inner(Wrapper::Close);
328 self.state.borrow_mut().close();
329 Poll::Ready(Ok(()))
330 }
331 }
332}
333
334impl<T, Codec, Incoming, Outgoing> Stream for Transport<T, Codec, Incoming, Outgoing>
335where
336 T: AsRef<EventTarget> + PostMessage,
337 Codec: kodec::Codec,
338{
339 type Item = Result<Incoming, Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
340
341 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
342 let mut state = self.state.borrow_mut();
343 if state.closed && state.incoming.is_empty() {
344 Poll::Ready(None)
345 } else if let Some(item) = state.incoming.pop_front() {
346 Poll::Ready(Some(item))
347 } else {
348 state.update_waker_with(cx.waker());
349 Poll::Pending
350 }
351 }
352}
353
354impl<T, Codec, Incoming, Outgoing> FusedStream for Transport<T, Codec, Incoming, Outgoing>
355where
356 T: AsRef<EventTarget> + PostMessage,
357 Codec: kodec::Codec,
358{
359 fn is_terminated(&self) -> bool {
360 let state = self.state.borrow();
361 state.closed && state.incoming.is_empty()
362 }
363}
364
365impl<T, Codec, Incoming, Outgoing> mezzenger::Reliable for Transport<T, Codec, Incoming, Outgoing>
366where
367 T: AsRef<EventTarget> + PostMessage,
368 Codec: kodec::Codec,
369{
370}
371
372impl<T, Codec, Incoming, Outgoing> mezzenger::Order for Transport<T, Codec, Incoming, Outgoing>
373where
374 T: AsRef<EventTarget> + PostMessage,
375 Codec: kodec::Codec,
376{
377}