rapace_transport_websocket/
lib.rs

//! rapace-transport-websocket: WebSocket transport for rapace.
//!
//! For browser clients or WebSocket-based infrastructure.
4
5use rapace_core::{
6    DecodeError, EncodeCtx, EncodeError, Frame, FrameView, INLINE_PAYLOAD_SIZE,
7    INLINE_PAYLOAD_SLOT, MsgDescHot, Transport, TransportError,
8};
9
10mod shared {
11    use super::*;
12
13    /// Size of MsgDescHot in bytes (must be 64).
14    pub const DESC_SIZE: usize = 64;
15    const _: () = assert!(std::mem::size_of::<MsgDescHot>() == DESC_SIZE);
16
17    /// Internal storage for a received frame.
18    #[derive(Default)]
19    pub struct ReceivedFrame {
20        pub desc: MsgDescHot,
21        pub payload: Vec<u8>,
22    }
23
24    /// Convert MsgDescHot to raw bytes.
25    pub fn desc_to_bytes(desc: &MsgDescHot) -> [u8; DESC_SIZE] {
26        // SAFETY: MsgDescHot is repr(C), Copy, and exactly 64 bytes.
27        unsafe { std::mem::transmute_copy(desc) }
28    }
29
30    /// Convert raw bytes to MsgDescHot.
31    pub fn bytes_to_desc(bytes: &[u8; DESC_SIZE]) -> MsgDescHot {
32        // SAFETY: Same as desc_to_bytes.
33        unsafe { std::mem::transmute_copy(bytes) }
34    }
35
36    /// Encoder for WebSocket transport.
37    ///
38    /// Simply accumulates bytes into a Vec.
39    pub struct WebSocketEncoder {
40        desc: MsgDescHot,
41        payload: Vec<u8>,
42    }
43
44    impl Default for WebSocketEncoder {
45        fn default() -> Self {
46            Self::new()
47        }
48    }
49
50    impl WebSocketEncoder {
51        pub fn new() -> Self {
52            Self {
53                desc: MsgDescHot::new(),
54                payload: Vec::new(),
55            }
56        }
57    }
58
59    impl EncodeCtx for WebSocketEncoder {
60        fn encode_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
61            self.payload.extend_from_slice(bytes);
62            Ok(())
63        }
64
65        fn finish(self: Box<Self>) -> Result<Frame, EncodeError> {
66            Ok(Frame::with_payload(self.desc, self.payload))
67        }
68    }
69
70    /// Decoder for WebSocket transport.
71    pub struct WebSocketDecoder<'a> {
72        data: &'a [u8],
73        pos: usize,
74    }
75
76    impl<'a> WebSocketDecoder<'a> {
77        /// Create a new decoder from a byte slice.
78        pub fn new(data: &'a [u8]) -> Self {
79            Self { data, pos: 0 }
80        }
81    }
82
83    impl<'a> rapace_core::DecodeCtx<'a> for WebSocketDecoder<'a> {
84        fn decode_bytes(&mut self) -> Result<&'a [u8], DecodeError> {
85            let result = &self.data[self.pos..];
86            self.pos = self.data.len();
87            Ok(result)
88        }
89
90        fn remaining(&self) -> &'a [u8] {
91            &self.data[self.pos..]
92        }
93    }
94
95    pub use WebSocketDecoder as Decoder;
96    pub use WebSocketEncoder as Encoder;
97    pub use {bytes_to_desc as to_desc, desc_to_bytes as to_bytes};
98}
99
100pub use shared::{Decoder as WebSocketDecoder, Encoder as WebSocketEncoder};
101
102#[cfg(not(target_arch = "wasm32"))]
103mod native {
104    use super::shared::{DESC_SIZE, ReceivedFrame, to_bytes, to_desc};
105    use super::*;
106    use futures::stream::{SplitSink, SplitStream};
107    use futures::{SinkExt, StreamExt};
108    use parking_lot::Mutex as SyncMutex;
109    use std::sync::Arc;
110    use std::sync::atomic::{AtomicBool, Ordering};
111    use tokio::io::{AsyncRead, AsyncWrite};
112    use tokio::sync::Mutex as AsyncMutex;
113    use tokio_tungstenite::WebSocketStream;
114    use tokio_tungstenite::tungstenite::Message;
115
116    /// WebSocket-based transport implementation.
117    ///
118    /// Works with any WebSocket stream (TCP, TLS, etc.).
119    pub struct WebSocketTransport<S> {
120        inner: Arc<WebSocketInner<S>>,
121    }
122
123    struct WebSocketInner<S> {
124        /// Write half of the WebSocket (async mutex for holding across awaits).
125        sink: AsyncMutex<SplitSink<WebSocketStream<S>, Message>>,
126        /// Read half of the WebSocket (async mutex for holding across awaits).
127        stream: AsyncMutex<SplitStream<WebSocketStream<S>>>,
128        /// Buffer for the last received frame (for FrameView lifetime).
129        last_frame: SyncMutex<Option<ReceivedFrame>>,
130        /// Whether the transport is closed.
131        closed: AtomicBool,
132    }
133
134    impl<S> WebSocketTransport<S>
135    where
136        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
137    {
138        /// Create a new WebSocket transport wrapping the given WebSocket stream.
139        pub fn new(ws: WebSocketStream<S>) -> Self {
140            let (sink, stream) = ws.split();
141            Self {
142                inner: Arc::new(WebSocketInner {
143                    sink: AsyncMutex::new(sink),
144                    stream: AsyncMutex::new(stream),
145                    last_frame: SyncMutex::new(None),
146                    closed: AtomicBool::new(false),
147                }),
148            }
149        }
150
151        /// Check if the transport is closed.
152        pub fn is_closed(&self) -> bool {
153            self.inner.closed.load(Ordering::Acquire)
154        }
155    }
156
157    impl WebSocketTransport<tokio::io::DuplexStream> {
158        /// Create a connected pair of WebSocket transports for testing.
159        ///
160        /// Uses `tokio::io::duplex` with WebSocket framing internally.
161        pub async fn pair() -> (Self, Self) {
162            // 64KB buffer should be plenty for testing
163            let (client_stream, server_stream) = tokio::io::duplex(65536);
164
165            // Wrap both ends with WebSocket framing.
166            // We use the client/server handshake over the duplex streams.
167            let (ws_a, ws_b) = tokio::join!(
168                async {
169                    tokio_tungstenite::client_async("ws://localhost/", client_stream)
170                        .await
171                        .expect("client handshake failed")
172                        .0
173                },
174                async {
175                    tokio_tungstenite::accept_async(server_stream)
176                        .await
177                        .expect("server handshake failed")
178                }
179            );
180
181            (Self::new(ws_a), Self::new(ws_b))
182        }
183    }
184
185    impl<S> Transport for WebSocketTransport<S>
186    where
187        S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
188    {
189        async fn send_frame(&self, frame: &Frame) -> Result<(), TransportError> {
190            if self.is_closed() {
191                return Err(TransportError::Closed);
192            }
193
194            let payload = frame.payload();
195
196            // Build message: descriptor + payload
197            let mut data = Vec::with_capacity(DESC_SIZE + payload.len());
198            data.extend_from_slice(&to_bytes(&frame.desc));
199            data.extend_from_slice(payload);
200
201            // Send as binary WebSocket message
202            let mut sink = self.inner.sink.lock().await;
203            sink.send(Message::Binary(data.into())).await.map_err(|e| {
204                TransportError::Io(std::io::Error::other(format!("websocket send: {}", e)))
205            })?;
206
207            Ok(())
208        }
209
210        async fn recv_frame(&self) -> Result<FrameView<'_>, TransportError> {
211            if self.is_closed() {
212                return Err(TransportError::Closed);
213            }
214
215            let mut stream = self.inner.stream.lock().await;
216
217            // Read next message
218            loop {
219                let msg = stream
220                    .next()
221                    .await
222                    .ok_or(TransportError::Closed)?
223                    .map_err(|e| {
224                        TransportError::Io(std::io::Error::other(format!("websocket recv: {}", e)))
225                    })?;
226
227                match msg {
228                    Message::Binary(data) => {
229                        // Validate minimum length
230                        if data.len() < DESC_SIZE {
231                            return Err(TransportError::Io(std::io::Error::new(
232                                std::io::ErrorKind::InvalidData,
233                                format!("frame too small: {} < {}", data.len(), DESC_SIZE),
234                            )));
235                        }
236
237                        // Parse descriptor
238                        let desc_bytes: [u8; DESC_SIZE] = data[..DESC_SIZE].try_into().unwrap();
239                        let mut desc = to_desc(&desc_bytes);
240
241                        // Extract payload
242                        let payload = data[DESC_SIZE..].to_vec();
243                        let payload_len = payload.len();
244
245                        // Drop stream lock before storing frame
246                        drop(stream);
247
248                        // Update desc.payload_len to match actual received payload
249                        desc.payload_len = payload_len as u32;
250
251                        // If payload fits inline, mark it as inline
252                        if payload_len <= INLINE_PAYLOAD_SIZE {
253                            desc.payload_slot = INLINE_PAYLOAD_SLOT;
254                            desc.inline_payload[..payload_len].copy_from_slice(&payload);
255                        } else {
256                            // Mark as external payload
257                            desc.payload_slot = 0;
258                        }
259
260                        // Store frame for FrameView lifetime
261                        {
262                            let mut last = self.inner.last_frame.lock();
263                            *last = Some(ReceivedFrame { desc, payload });
264                        }
265
266                        // Create FrameView from stored frame
267                        let last = self.inner.last_frame.lock();
268                        let frame_ref = last.as_ref().unwrap();
269
270                        let desc_ptr = &frame_ref.desc as *const MsgDescHot;
271                        let payload_slice = if frame_ref.desc.is_inline() {
272                            frame_ref.desc.inline_payload()
273                        } else {
274                            &frame_ref.payload
275                        };
276                        let payload_ptr = payload_slice.as_ptr();
277                        let payload_len = payload_slice.len();
278
279                        // SAFETY: Extending lifetime is safe because:
280                        // - Data lives in Arc<WebSocketInner> which outlives &self
281                        // - FrameView borrows &self, preventing concurrent recv_frame
282                        let desc: &MsgDescHot = unsafe { &*desc_ptr };
283                        let payload: &[u8] =
284                            unsafe { std::slice::from_raw_parts(payload_ptr, payload_len) };
285
286                        return Ok(FrameView::new(desc, payload));
287                    }
288                    Message::Close(_) => {
289                        self.inner.closed.store(true, Ordering::Release);
290                        return Err(TransportError::Closed);
291                    }
292                    Message::Ping(_) | Message::Pong(_) | Message::Text(_) | Message::Frame(_) => {
293                        // Ignore ping/pong/text frames, continue reading
294                        continue;
295                    }
296                }
297            }
298        }
299
300        fn encoder(&self) -> Box<dyn EncodeCtx + '_> {
301            Box::new(WebSocketEncoder::new())
302        }
303
304        async fn close(&self) -> Result<(), TransportError> {
305            self.inner.closed.store(true, Ordering::Release);
306
307            // Send WebSocket close frame
308            let mut sink = self.inner.sink.lock().await;
309            let _ = sink.send(Message::Close(None)).await;
310
311            Ok(())
312        }
313    }
314
315    #[cfg(test)]
316    mod tests {
317        use super::*;
318        use rapace_core::FrameFlags;
319
320        #[tokio::test]
321        async fn test_pair_creation() {
322            let (a, b) = WebSocketTransport::pair().await;
323            assert!(!a.is_closed());
324            assert!(!b.is_closed());
325        }
326
327        #[tokio::test]
328        async fn test_send_recv_inline() {
329            let (a, b) = WebSocketTransport::pair().await;
330
331            // Create a frame with inline payload
332            let mut desc = MsgDescHot::new();
333            desc.msg_id = 1;
334            desc.channel_id = 1;
335            desc.method_id = 42;
336            desc.flags = FrameFlags::DATA;
337
338            let frame = Frame::with_inline_payload(desc, b"hello").unwrap();
339
340            // Send from A
341            a.send_frame(&frame).await.unwrap();
342
343            // Receive on B
344            let view = b.recv_frame().await.unwrap();
345            assert_eq!(view.desc.msg_id, 1);
346            assert_eq!(view.desc.channel_id, 1);
347            assert_eq!(view.desc.method_id, 42);
348            assert_eq!(view.payload, b"hello");
349        }
350
351        #[tokio::test]
352        async fn test_send_recv_external_payload() {
353            let (a, b) = WebSocketTransport::pair().await;
354
355            let mut desc = MsgDescHot::new();
356            desc.msg_id = 2;
357            desc.flags = FrameFlags::DATA;
358
359            let payload = vec![0u8; 1000]; // Larger than inline
360            let frame = Frame::with_payload(desc, payload.clone());
361
362            a.send_frame(&frame).await.unwrap();
363
364            let view = b.recv_frame().await.unwrap();
365            assert_eq!(view.desc.msg_id, 2);
366            assert_eq!(view.payload.len(), 1000);
367        }
368
369        #[tokio::test]
370        async fn test_bidirectional() {
371            let (a, b) = WebSocketTransport::pair().await;
372
373            // A -> B
374            let mut desc_a = MsgDescHot::new();
375            desc_a.msg_id = 1;
376            let frame_a = Frame::with_inline_payload(desc_a, b"from A").unwrap();
377            a.send_frame(&frame_a).await.unwrap();
378
379            // B -> A
380            let mut desc_b = MsgDescHot::new();
381            desc_b.msg_id = 2;
382            let frame_b = Frame::with_inline_payload(desc_b, b"from B").unwrap();
383            b.send_frame(&frame_b).await.unwrap();
384
385            // Receive both
386            let view_b = b.recv_frame().await.unwrap();
387            assert_eq!(view_b.payload, b"from A");
388
389            let view_a = a.recv_frame().await.unwrap();
390            assert_eq!(view_a.payload, b"from B");
391        }
392
393        #[tokio::test]
394        async fn test_close() {
395            let (a, _b) = WebSocketTransport::pair().await;
396
397            a.close().await.unwrap();
398            assert!(a.is_closed());
399
400            // Sending on closed transport should fail
401            let frame = Frame::new(MsgDescHot::new());
402            assert!(matches!(
403                a.send_frame(&frame).await,
404                Err(TransportError::Closed)
405            ));
406        }
407
408        #[tokio::test]
409        async fn test_encoder() {
410            let (a, _b) = WebSocketTransport::pair().await;
411
412            let mut encoder = a.encoder();
413            encoder.encode_bytes(b"test data").unwrap();
414            let frame = encoder.finish().unwrap();
415
416            assert_eq!(frame.payload(), b"test data");
417        }
418    }
419
420    /// Conformance tests using rapace-testkit.
421    #[cfg(test)]
422    mod conformance_tests {
423        use super::*;
424        use rapace_testkit::{TestError, TransportFactory};
425
426        struct WebSocketFactory;
427
428        impl TransportFactory for WebSocketFactory {
429            type Transport = WebSocketTransport<tokio::io::DuplexStream>;
430
431            async fn connect_pair() -> Result<(Self::Transport, Self::Transport), TestError> {
432                Ok(WebSocketTransport::pair().await)
433            }
434        }
435
436        #[tokio::test]
437        async fn unary_happy_path() {
438            rapace_testkit::run_unary_happy_path::<WebSocketFactory>().await;
439        }
440
441        #[tokio::test]
442        async fn unary_multiple_calls() {
443            rapace_testkit::run_unary_multiple_calls::<WebSocketFactory>().await;
444        }
445
446        #[tokio::test]
447        async fn ping_pong() {
448            rapace_testkit::run_ping_pong::<WebSocketFactory>().await;
449        }
450
451        #[tokio::test]
452        async fn deadline_success() {
453            rapace_testkit::run_deadline_success::<WebSocketFactory>().await;
454        }
455
456        #[tokio::test]
457        async fn deadline_exceeded() {
458            rapace_testkit::run_deadline_exceeded::<WebSocketFactory>().await;
459        }
460
461        #[tokio::test]
462        async fn cancellation() {
463            rapace_testkit::run_cancellation::<WebSocketFactory>().await;
464        }
465
466        #[tokio::test]
467        async fn credit_grant() {
468            rapace_testkit::run_credit_grant::<WebSocketFactory>().await;
469        }
470
471        #[tokio::test]
472        async fn error_response() {
473            rapace_testkit::run_error_response::<WebSocketFactory>().await;
474        }
475
476        // Session-level tests (semantic enforcement)
477
478        #[tokio::test]
479        async fn session_credit_exhaustion() {
480            rapace_testkit::run_session_credit_exhaustion::<WebSocketFactory>().await;
481        }
482
483        #[tokio::test]
484        async fn session_cancelled_channel_drop() {
485            rapace_testkit::run_session_cancelled_channel_drop::<WebSocketFactory>().await;
486        }
487
488        #[tokio::test]
489        async fn session_cancel_control_frame() {
490            rapace_testkit::run_session_cancel_control_frame::<WebSocketFactory>().await;
491        }
492
493        #[tokio::test]
494        async fn session_grant_credits_control_frame() {
495            rapace_testkit::run_session_grant_credits_control_frame::<WebSocketFactory>().await;
496        }
497
498        #[tokio::test]
499        async fn session_deadline_check() {
500            rapace_testkit::run_session_deadline_check::<WebSocketFactory>().await;
501        }
502
503        // Streaming tests
504
505        #[tokio::test]
506        async fn server_streaming_happy_path() {
507            rapace_testkit::run_server_streaming_happy_path::<WebSocketFactory>().await;
508        }
509
510        #[tokio::test]
511        async fn client_streaming_happy_path() {
512            rapace_testkit::run_client_streaming_happy_path::<WebSocketFactory>().await;
513        }
514
515        #[tokio::test]
516        async fn bidirectional_streaming() {
517            rapace_testkit::run_bidirectional_streaming::<WebSocketFactory>().await;
518        }
519
520        #[tokio::test]
521        async fn streaming_cancellation() {
522            rapace_testkit::run_streaming_cancellation::<WebSocketFactory>().await;
523        }
524
525        // Macro-generated streaming tests
526
527        #[tokio::test]
528        async fn macro_server_streaming() {
529            rapace_testkit::run_macro_server_streaming::<WebSocketFactory>().await;
530        }
531    }
532}
533
534#[cfg(not(target_arch = "wasm32"))]
535pub use native::WebSocketTransport;
536
#[cfg(target_arch = "wasm32")]
mod wasm {
    use super::shared::{DESC_SIZE, ReceivedFrame, to_bytes, to_desc};
    use super::*;
    use gloo_timers::future::TimeoutFuture;
    use parking_lot::Mutex as SyncMutex;
    use std::cell::{Cell, RefCell};
    use std::collections::VecDeque;
    use std::future::Future;
    use std::pin::Pin;
    use std::rc::Rc;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::task::{Context, Poll};
    use wasm_bindgen::JsCast;
    use wasm_bindgen::prelude::*;
    use web_sys::{BinaryType, CloseEvent, ErrorEvent, MessageEvent, WebSocket};

    /// WebSocket transport implementation for browser environments.
    ///
    /// Each rapace frame is carried as one binary WebSocket message: the
    /// `DESC_SIZE`-byte descriptor followed by the payload bytes.
    pub struct WebSocketTransport {
        inner: Arc<WebSocketInner>,
    }

    struct WebSocketInner {
        /// Browser socket plus the state filled in by its event handlers.
        ws: WasmWebSocket,
        /// Storage for the last received frame; the `FrameView` returned by
        /// `recv_frame` points into this slot.
        last_frame: SyncMutex<Option<ReceivedFrame>>,
        /// Whether the transport is closed.
        closed: AtomicBool,
    }

    /// Splits one binary WebSocket message into descriptor and payload.
    ///
    /// The descriptor's `payload_len` is overwritten with the number of bytes
    /// actually received, and `payload_slot` is set to `INLINE_PAYLOAD_SLOT`
    /// (with the bytes copied into `inline_payload`) when the payload fits
    /// inline, or to `0` otherwise.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::Io` with `InvalidData` when `data` is shorter
    /// than a descriptor or when the payload length does not fit in `u32`.
    fn decode_frame(data: &[u8]) -> Result<ReceivedFrame, TransportError> {
        if data.len() < DESC_SIZE {
            return Err(TransportError::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("frame too small: {} < {}", data.len(), DESC_SIZE),
            )));
        }

        let (head, body) = data.split_at(DESC_SIZE);
        let desc_bytes: [u8; DESC_SIZE] = head
            .try_into()
            .expect("split_at(DESC_SIZE) yields exactly DESC_SIZE bytes");
        let mut desc = to_desc(&desc_bytes);

        let payload = body.to_vec();
        let payload_len = payload.len();

        // Checked conversion instead of a silently truncating `as u32`.
        desc.payload_len = u32::try_from(payload_len).map_err(|_| {
            TransportError::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("payload too large: {} bytes", payload_len),
            ))
        })?;

        if payload_len <= INLINE_PAYLOAD_SIZE {
            desc.payload_slot = INLINE_PAYLOAD_SLOT;
            desc.inline_payload[..payload_len].copy_from_slice(&payload);
        } else {
            // Mark as external payload.
            desc.payload_slot = 0;
        }

        Ok(ReceivedFrame { desc, payload })
    }

    impl WebSocketTransport {
        /// Connect to a WebSocket server at the given URL.
        ///
        /// # Errors
        ///
        /// Returns `TransportError::Io` if the socket cannot be created or the
        /// connection fails or closes before it opens.
        pub async fn connect(url: &str) -> Result<Self, TransportError> {
            let ws = WasmWebSocket::connect(url).await?;
            Ok(Self {
                inner: Arc::new(WebSocketInner {
                    ws,
                    last_frame: SyncMutex::new(None),
                    closed: AtomicBool::new(false),
                }),
            })
        }

        /// Check if the transport is closed.
        ///
        /// Public for parity with the native transport's `is_closed`.
        pub fn is_closed(&self) -> bool {
            self.inner.closed.load(Ordering::Acquire)
        }

        /// Stores `frame` as the last received frame and returns a view into
        /// the stored copy.
        fn publish_frame(&self, frame: ReceivedFrame) -> FrameView<'_> {
            // One critical section: replace the slot and take the pointers.
            let mut slot = self.inner.last_frame.lock();
            let stored: &ReceivedFrame = slot.insert(frame);

            let payload: &[u8] = if stored.desc.is_inline() {
                stored.desc.inline_payload()
            } else {
                &stored.payload
            };
            let desc_ptr: *const MsgDescHot = &stored.desc;
            let (payload_ptr, payload_len) = (payload.as_ptr(), payload.len());

            // SAFETY: both pointers target data stored in
            // `self.inner.last_frame`, which is owned by the `Arc` held by
            // `self` and therefore lives at least as long as the returned
            // borrow of `self`.
            //
            // NOTE(review): the returned view only holds a *shared* borrow of
            // `self`, so the compiler does not stop another `recv_frame` call
            // from replacing the slot while the view is alive. Callers must
            // drop the view before receiving again.
            let desc: &MsgDescHot = unsafe { &*desc_ptr };
            let payload: &[u8] = unsafe { std::slice::from_raw_parts(payload_ptr, payload_len) };

            FrameView::new(desc, payload)
        }
    }

    impl Transport for WebSocketTransport {
        /// Sends `frame` as a single binary WebSocket message.
        ///
        /// # Errors
        ///
        /// `TransportError::Closed` if the transport or socket is closed,
        /// `TransportError::Io` if an error was recorded or the send fails.
        async fn send_frame(&self, frame: &Frame) -> Result<(), TransportError> {
            if self.is_closed() {
                return Err(TransportError::Closed);
            }

            // Build message: descriptor + payload
            let payload = frame.payload();
            let mut data = Vec::with_capacity(DESC_SIZE + payload.len());
            data.extend_from_slice(&to_bytes(&frame.desc));
            data.extend_from_slice(payload);

            self.inner.ws.send(&data)?;
            Ok(())
        }

        /// Waits for the next binary message and returns it as a frame view.
        ///
        /// The returned view points into storage that the next `recv_frame`
        /// call overwrites, so it must be dropped before receiving again.
        ///
        /// # Errors
        ///
        /// `TransportError::Closed` if the transport or socket is closed (the
        /// transport is then marked closed). `TransportError::Io` for socket
        /// errors and malformed messages.
        async fn recv_frame(&self) -> Result<FrameView<'_>, TransportError> {
            if self.is_closed() {
                return Err(TransportError::Closed);
            }

            let data = match self.inner.ws.recv().await {
                Ok(data) => data,
                Err(err) => {
                    // Remember a peer-initiated close so later calls fail fast.
                    if matches!(err, TransportError::Closed) {
                        self.inner.closed.store(true, Ordering::Release);
                    }
                    return Err(err);
                }
            };

            let frame = decode_frame(&data)?;
            Ok(self.publish_frame(frame))
        }

        /// Returns a fresh encoder that accumulates payload bytes in memory.
        fn encoder(&self) -> Box<dyn EncodeCtx + '_> {
            Box::new(WebSocketEncoder::new())
        }

        /// Marks the transport closed and closes the browser socket.
        async fn close(&self) -> Result<(), TransportError> {
            self.inner.closed.store(true, Ordering::Release);
            self.inner.ws.close();
            Ok(())
        }
    }

    /// A wasm-compatible WebSocket wrapper.
    ///
    /// The browser's event handlers push into `received`, `error` and
    /// `closed`; `send`/`recv` read that state.
    struct WasmWebSocket {
        ws: WebSocket,
        /// Binary messages received but not yet handed to `recv`.
        received: Rc<RefCell<VecDeque<Vec<u8>>>>,
        /// Most recent error reported by the socket, if any.
        error: Rc<RefCell<Option<String>>>,
        /// Set once the socket reports a close or `close()` is called.
        closed: Rc<Cell<bool>>,
    }

    // SAFETY (NOTE(review)): `Rc`, `RefCell`, `Cell` and the JS handle are
    // not thread-safe. These impls are only sound if the value is never
    // actually touched from more than one thread -- presumably relying on
    // browser wasm32 being single-threaded. Confirm before enabling wasm
    // threads/atomics.
    unsafe impl Send for WasmWebSocket {}
    unsafe impl Sync for WasmWebSocket {}

    /// Reads the `message` property of an error event, if it is a non-empty
    /// string.
    ///
    /// Read reflectively because browsers typically dispatch a plain `Event`
    /// (without `message`) for WebSocket errors.
    fn error_event_message(event: &ErrorEvent) -> Option<String> {
        js_sys::Reflect::get(event, &JsValue::from_str("message"))
            .ok()
            .and_then(|value| value.as_string())
            .filter(|text| !text.is_empty())
    }

    impl WasmWebSocket {
        /// Opens a socket to `url` and waits until it is connected.
        ///
        /// All event handlers are installed before waiting, so messages or a
        /// close arriving right after the open event are not lost. The
        /// handler closures are intentionally leaked (`forget`) so they stay
        /// valid for as long as the browser may call them.
        ///
        /// # Errors
        ///
        /// Returns `TransportError::Io` if the socket cannot be created, an
        /// error event fires, or the socket closes before it opens.
        async fn connect(url: &str) -> Result<Self, TransportError> {
            let ws = WebSocket::new(url).map_err(js_error_from_value)?;
            ws.set_binary_type(BinaryType::Arraybuffer);

            let received = Rc::new(RefCell::new(VecDeque::new()));
            let error: Rc<RefCell<Option<String>>> = Rc::new(RefCell::new(None));
            let closed = Rc::new(Cell::new(false));
            let opened = Rc::new(Cell::new(false));

            {
                let opened = Rc::clone(&opened);
                let onopen = Closure::<dyn FnMut()>::new(move || opened.set(true));
                ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
                onopen.forget();
            }

            {
                let received = Rc::clone(&received);
                let onmessage = Closure::<dyn FnMut(MessageEvent)>::new(move |e: MessageEvent| {
                    // Only binary (ArrayBuffer) messages are queued.
                    if let Ok(abuf) = e.data().dyn_into::<js_sys::ArrayBuffer>() {
                        let array = js_sys::Uint8Array::new(&abuf);
                        received.borrow_mut().push_back(array.to_vec());
                    }
                });
                ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
                onmessage.forget();
            }

            {
                let error = Rc::clone(&error);
                let opened = Rc::clone(&opened);
                let onerror = Closure::<dyn FnMut(ErrorEvent)>::new(move |e: ErrorEvent| {
                    let detail = error_event_message(&e);
                    let text = if opened.get() {
                        format!("WebSocket error: {}", detail.unwrap_or_default())
                    } else {
                        detail.unwrap_or_else(|| "WebSocket connection failed".to_string())
                    };
                    *error.borrow_mut() = Some(text);
                });
                ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
                onerror.forget();
            }

            {
                let closed = Rc::clone(&closed);
                let onclose = Closure::<dyn FnMut(CloseEvent)>::new(move |_e: CloseEvent| {
                    closed.set(true);
                });
                ws.set_onclose(Some(onclose.as_ref().unchecked_ref()));
                onclose.forget();
            }

            // Poll until the socket opens, fails, or closes.
            loop {
                // Take the value first so no `RefCell` borrow is held across
                // the await below.
                let failure = error.borrow_mut().take();
                if let Some(msg) = failure {
                    return Err(js_error_from_msg(msg));
                }
                if opened.get() {
                    break;
                }
                if closed.get() {
                    // Closed without ever opening: fail instead of polling
                    // forever.
                    return Err(js_error_from_msg(
                        "WebSocket closed before the connection opened",
                    ));
                }
                SendTimeoutFuture::new(10).await;
            }

            Ok(Self {
                ws,
                received,
                error,
                closed,
            })
        }

        /// Sends `data` as one binary message.
        ///
        /// Fails with `Closed` if the socket is closed and with `Io` if an
        /// error was recorded earlier or the browser rejects the send.
        fn send(&self, data: &[u8]) -> Result<(), TransportError> {
            if self.closed.get() {
                return Err(TransportError::Closed);
            }

            if let Some(err) = self.error.borrow().as_ref() {
                return Err(js_error_from_msg(err.clone()));
            }

            self.ws
                .send_with_u8_array(data)
                .map_err(js_error_from_value)
        }

        /// Returns the next queued binary message, polling every millisecond
        /// until one arrives.
        ///
        /// A recorded error takes priority over queued data; queued data takes
        /// priority over the closed flag, so messages received before a close
        /// are still delivered.
        async fn recv(&self) -> Result<Vec<u8>, TransportError> {
            loop {
                if let Some(err) = self.error.borrow().as_ref() {
                    return Err(js_error_from_msg(err.clone()));
                }

                if let Some(data) = self.received.borrow_mut().pop_front() {
                    return Ok(data);
                }

                if self.closed.get() {
                    return Err(TransportError::Closed);
                }

                SendTimeoutFuture::new(1).await;
            }
        }

        /// Closes the browser socket (errors ignored) and sets the closed flag.
        fn close(&self) {
            let _ = self.ws.close();
            self.closed.set(true);
        }
    }

    /// Thin wrapper that lets a `gloo_timers` timeout be awaited inside
    /// futures that are required to be `Send`.
    struct SendTimeoutFuture {
        inner: TimeoutFuture,
    }

    impl SendTimeoutFuture {
        /// Creates a timer that completes after `ms` milliseconds.
        fn new(ms: u32) -> Self {
            Self {
                inner: TimeoutFuture::new(ms),
            }
        }
    }

    impl Future for SendTimeoutFuture {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            Pin::new(&mut self.inner).poll(cx)
        }
    }

    // SAFETY (NOTE(review)): same assumption as for `WasmWebSocket` -- only
    // sound while the future is never moved to or polled from another thread.
    unsafe impl Send for SendTimeoutFuture {}

    /// Converts a JavaScript error value into `TransportError::Io`.
    ///
    /// Uses the value itself if it is a string, otherwise its JSON form,
    /// otherwise its debug representation.
    fn js_error_from_value(err: JsValue) -> TransportError {
        let msg = if let Some(s) = err.as_string() {
            s
        } else if let Ok(js_string) = js_sys::JSON::stringify(&err) {
            js_string.as_string().unwrap_or_else(|| format!("{err:?}"))
        } else {
            format!("{err:?}")
        };
        TransportError::Io(std::io::Error::other(msg))
    }

    /// Wraps a plain message in `TransportError::Io`.
    fn js_error_from_msg<S: Into<String>>(msg: S) -> TransportError {
        TransportError::Io(std::io::Error::other(msg.into()))
    }
}
831
832#[cfg(target_arch = "wasm32")]
833pub use wasm::WebSocketTransport;