jacquard_common/
websocket.rs

1//! WebSocket client abstraction
2
3use crate::CowStr;
4use crate::stream::StreamError;
5use bytes::Bytes;
6use n0_future::Stream;
7use std::borrow::Borrow;
8use std::fmt::{self, Display};
9use std::future::Future;
10use std::ops::Deref;
11use std::pin::Pin;
12use url::Url;
13
14/// UTF-8 validated bytes for WebSocket text messages
15#[repr(transparent)]
16#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
17pub struct WsText(Bytes);
18
19impl WsText {
20    /// Create from static string
21    pub const fn from_static(s: &'static str) -> Self {
22        Self(Bytes::from_static(s.as_bytes()))
23    }
24
25    /// Get as string slice
26    pub fn as_str(&self) -> &str {
27        unsafe { std::str::from_utf8_unchecked(&self.0) }
28    }
29
30    /// Create from bytes without validation (caller must ensure UTF-8)
31    ///
32    /// # Safety
33    /// Bytes must be valid UTF-8
34    pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self {
35        Self(bytes)
36    }
37
38    /// Convert into underlying bytes
39    pub fn into_bytes(self) -> Bytes {
40        self.0
41    }
42}
43
44impl Deref for WsText {
45    type Target = str;
46    fn deref(&self) -> &str {
47        self.as_str()
48    }
49}
50
51impl AsRef<str> for WsText {
52    fn as_ref(&self) -> &str {
53        self.as_str()
54    }
55}
56
57impl AsRef<[u8]> for WsText {
58    fn as_ref(&self) -> &[u8] {
59        &self.0
60    }
61}
62
63impl AsRef<Bytes> for WsText {
64    fn as_ref(&self) -> &Bytes {
65        &self.0
66    }
67}
68
69impl Borrow<str> for WsText {
70    fn borrow(&self) -> &str {
71        self.as_str()
72    }
73}
74
75impl Display for WsText {
76    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
77        Display::fmt(self.as_str(), f)
78    }
79}
80
81impl From<String> for WsText {
82    fn from(s: String) -> Self {
83        Self(Bytes::from(s))
84    }
85}
86
87impl From<&str> for WsText {
88    fn from(s: &str) -> Self {
89        Self(Bytes::copy_from_slice(s.as_bytes()))
90    }
91}
92
93impl From<&String> for WsText {
94    fn from(s: &String) -> Self {
95        Self::from(s.as_str())
96    }
97}
98
99impl TryFrom<Bytes> for WsText {
100    type Error = std::str::Utf8Error;
101    fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
102        std::str::from_utf8(&bytes)?;
103        Ok(Self(bytes))
104    }
105}
106
107impl TryFrom<Vec<u8>> for WsText {
108    type Error = std::str::Utf8Error;
109    fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> {
110        Self::try_from(Bytes::from(vec))
111    }
112}
113
114impl From<WsText> for Bytes {
115    fn from(t: WsText) -> Bytes {
116        t.0
117    }
118}
119
120impl Default for WsText {
121    fn default() -> Self {
122        Self(Bytes::new())
123    }
124}
125
/// WebSocket close status code (the values defined in RFC 6455, section 7.4.1).
///
/// Codes without a dedicated variant are carried in `Other`. The `From<u16>`
/// conversion never produces `Other` for a code that has a named variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum CloseCode {
    /// Normal closure
    Normal = 1000,
    /// Endpoint going away
    Away = 1001,
    /// Protocol error
    Protocol = 1002,
    /// Unsupported data
    Unsupported = 1003,
    /// Invalid frame payload data
    Invalid = 1007,
    /// Policy violation
    Policy = 1008,
    /// Message too big
    Size = 1009,
    /// Extension negotiation failure
    Extension = 1010,
    /// Unexpected condition
    Error = 1011,
    /// TLS handshake failure
    Tls = 1015,
    /// Other code
    Other(u16),
}

impl From<u16> for CloseCode {
    /// Maps a raw wire code to its named variant, falling back to `Other`.
    fn from(code: u16) -> Self {
        // Every variant with a fixed wire value. Their numbers are taken from
        // `From<CloseCode> for u16`, so the numeric table is written down only once.
        const NAMED: [CloseCode; 10] = [
            CloseCode::Normal,
            CloseCode::Away,
            CloseCode::Protocol,
            CloseCode::Unsupported,
            CloseCode::Invalid,
            CloseCode::Policy,
            CloseCode::Size,
            CloseCode::Extension,
            CloseCode::Error,
            CloseCode::Tls,
        ];

        NAMED
            .iter()
            .copied()
            .find(|named| u16::from(*named) == code)
            .unwrap_or(CloseCode::Other(code))
    }
}

impl From<CloseCode> for u16 {
    /// Returns the numeric wire value of the close code.
    fn from(code: CloseCode) -> u16 {
        match code {
            CloseCode::Other(raw) => raw,
            CloseCode::Normal => 1000,
            CloseCode::Away => 1001,
            CloseCode::Protocol => 1002,
            CloseCode::Unsupported => 1003,
            CloseCode::Invalid => 1007,
            CloseCode::Policy => 1008,
            CloseCode::Size => 1009,
            CloseCode::Extension => 1010,
            CloseCode::Error => 1011,
            CloseCode::Tls => 1015,
        }
    }
}
189
190/// WebSocket close frame
191#[derive(Debug, Clone, PartialEq, Eq)]
192pub struct CloseFrame<'a> {
193    /// Close code
194    pub code: CloseCode,
195    /// Close reason text
196    pub reason: CowStr<'a>,
197}
198
199impl<'a> CloseFrame<'a> {
200    /// Create a new close frame
201    pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self {
202        Self {
203            code,
204            reason: reason.into(),
205        }
206    }
207}
208
209/// WebSocket message
210#[derive(Debug, Clone, PartialEq, Eq)]
211pub enum WsMessage {
212    /// Text message (UTF-8)
213    Text(WsText),
214    /// Binary message
215    Binary(Bytes),
216    /// Close frame
217    Close(Option<CloseFrame<'static>>),
218}
219
220impl WsMessage {
221    /// Check if this is a text message
222    pub fn is_text(&self) -> bool {
223        matches!(self, WsMessage::Text(_))
224    }
225
226    /// Check if this is a binary message
227    pub fn is_binary(&self) -> bool {
228        matches!(self, WsMessage::Binary(_))
229    }
230
231    /// Check if this is a close message
232    pub fn is_close(&self) -> bool {
233        matches!(self, WsMessage::Close(_))
234    }
235
236    /// Get as text, if this is a text message
237    pub fn as_text(&self) -> Option<&str> {
238        match self {
239            WsMessage::Text(t) => Some(t.as_str()),
240            _ => None,
241        }
242    }
243
244    /// Get as bytes
245    pub fn as_bytes(&self) -> Option<&[u8]> {
246        match self {
247            WsMessage::Text(t) => Some(t.as_ref()),
248            WsMessage::Binary(b) => Some(b),
249            WsMessage::Close(_) => None,
250        }
251    }
252}
253
254impl From<WsText> for WsMessage {
255    fn from(text: WsText) -> Self {
256        WsMessage::Text(text)
257    }
258}
259
260impl From<String> for WsMessage {
261    fn from(s: String) -> Self {
262        WsMessage::Text(WsText::from(s))
263    }
264}
265
266impl From<&str> for WsMessage {
267    fn from(s: &str) -> Self {
268        WsMessage::Text(WsText::from(s))
269    }
270}
271
272impl From<Bytes> for WsMessage {
273    fn from(bytes: Bytes) -> Self {
274        WsMessage::Binary(bytes)
275    }
276}
277
278impl From<Vec<u8>> for WsMessage {
279    fn from(vec: Vec<u8>) -> Self {
280        WsMessage::Binary(Bytes::from(vec))
281    }
282}
283
284/// WebSocket message stream
285#[cfg(not(target_arch = "wasm32"))]
286pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>>);
287
288/// WebSocket message stream
289#[cfg(target_arch = "wasm32")]
290pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>>);
291
292impl WsStream {
293    /// Create a new message stream
294    #[cfg(not(target_arch = "wasm32"))]
295    pub fn new<S>(stream: S) -> Self
296    where
297        S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static,
298    {
299        Self(Box::pin(stream))
300    }
301
302    /// Create a new message stream
303    #[cfg(target_arch = "wasm32")]
304    pub fn new<S>(stream: S) -> Self
305    where
306        S: Stream<Item = Result<WsMessage, StreamError>> + 'static,
307    {
308        Self(Box::pin(stream))
309    }
310
311    /// Convert into the inner pinned boxed stream
312    #[cfg(not(target_arch = "wasm32"))]
313    pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>> {
314        self.0
315    }
316
317    /// Convert into the inner pinned boxed stream
318    #[cfg(target_arch = "wasm32")]
319    pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>> {
320        self.0
321    }
322
323    /// Split this stream into two streams that both receive all messages
324    ///
325    /// Messages are cloned (cheaply via Bytes rc). Spawns a forwarder task.
326    /// Both returned streams will receive all messages from the original stream.
327    /// The forwarder continues as long as at least one stream is alive.
328    /// If the underlying stream errors, both teed streams will end.
329    pub fn tee(self) -> (WsStream, WsStream) {
330        use futures::channel::mpsc;
331        use n0_future::StreamExt as _;
332
333        let (tx1, rx1) = mpsc::unbounded();
334        let (tx2, rx2) = mpsc::unbounded();
335
336        n0_future::task::spawn(async move {
337            let mut stream = self.0;
338            while let Some(result) = stream.next().await {
339                match result {
340                    Ok(msg) => {
341                        // Clone message (cheap - Bytes is rc'd)
342                        let msg2 = msg.clone();
343
344                        // Send to both channels, continue if at least one succeeds
345                        let send1 = tx1.unbounded_send(Ok(msg));
346                        let send2 = tx2.unbounded_send(Ok(msg2));
347
348                        // Only stop if both channels are closed
349                        if send1.is_err() && send2.is_err() {
350                            break;
351                        }
352                    }
353                    Err(_e) => {
354                        // Underlying stream errored, stop forwarding.
355                        // Both channels will close, ending both streams.
356                        break;
357                    }
358                }
359            }
360        });
361
362        (WsStream::new(rx1), WsStream::new(rx2))
363    }
364}
365
366impl fmt::Debug for WsStream {
367    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368        f.debug_struct("WsStream").finish_non_exhaustive()
369    }
370}
371
372/// WebSocket message sink
373#[cfg(not(target_arch = "wasm32"))]
374pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>>);
375
376/// WebSocket message sink
377#[cfg(target_arch = "wasm32")]
378pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>>);
379
380impl WsSink {
381    /// Create a new message sink
382    #[cfg(not(target_arch = "wasm32"))]
383    pub fn new<S>(sink: S) -> Self
384    where
385        S: n0_future::Sink<WsMessage, Error = StreamError> + Send + 'static,
386    {
387        Self(Box::pin(sink))
388    }
389
390    /// Create a new message sink
391    #[cfg(target_arch = "wasm32")]
392    pub fn new<S>(sink: S) -> Self
393    where
394        S: n0_future::Sink<WsMessage, Error = StreamError> + 'static,
395    {
396        Self(Box::pin(sink))
397    }
398
399    /// Convert into the inner boxed sink
400    #[cfg(not(target_arch = "wasm32"))]
401    pub fn into_inner(
402        self,
403    ) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
404        self.0
405    }
406
407    /// Convert into the inner boxed sink
408    #[cfg(target_arch = "wasm32")]
409    pub fn into_inner(self) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>> {
410        self.0
411    }
412
413    /// get a mutable reference to the inner boxed sink
414    #[cfg(not(target_arch = "wasm32"))]
415    pub fn get_mut(
416        &mut self,
417    ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
418        use std::borrow::BorrowMut;
419
420        self.0.borrow_mut()
421    }
422
423    /// get a mutable reference to the inner boxed sink
424    #[cfg(target_arch = "wasm32")]
425    pub fn get_mut(
426        &mut self,
427    ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + 'static>> {
428        use std::borrow::BorrowMut;
429
430        self.0.borrow_mut()
431    }
432}
433
434impl fmt::Debug for WsSink {
435    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436        f.debug_struct("WsSink").finish_non_exhaustive()
437    }
438}
439
440/// WebSocket client trait
441#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
442pub trait WebSocketClient: Sync {
443    /// Error type for WebSocket operations
444    type Error: std::error::Error + Send + Sync + 'static;
445
446    /// Connect to a WebSocket endpoint
447    fn connect(&self, url: Url) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>;
448
449    /// Connect to a WebSocket endpoint with custom headers
450    ///
451    /// Default implementation ignores headers and calls `connect()`.
452    /// Override this method to support authentication headers for subscriptions.
453    fn connect_with_headers(
454        &self,
455        url: Url,
456        _headers: Vec<(CowStr<'_>, CowStr<'_>)>,
457    ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> {
458        async move { self.connect(url).await }
459    }
460}
461
462/// WebSocket connection with bidirectional streams
463pub struct WebSocketConnection {
464    tx: WsSink,
465    rx: WsStream,
466}
467
468impl WebSocketConnection {
469    /// Create a new WebSocket connection
470    pub fn new(tx: WsSink, rx: WsStream) -> Self {
471        Self { tx, rx }
472    }
473
474    /// Get mutable access to the sender
475    pub fn sender_mut(&mut self) -> &mut WsSink {
476        &mut self.tx
477    }
478
479    /// Get mutable access to the receiver
480    pub fn receiver_mut(&mut self) -> &mut WsStream {
481        &mut self.rx
482    }
483
484    /// Get a reference to the receiver
485    pub fn receiver(&self) -> &WsStream {
486        &self.rx
487    }
488
489    /// Get a reference to the sender
490    pub fn sender(&self) -> &WsSink {
491        &self.tx
492    }
493
494    /// Split into sender and receiver
495    pub fn split(self) -> (WsSink, WsStream) {
496        (self.tx, self.rx)
497    }
498
499    /// Check if connection is open (always true for this abstraction)
500    pub fn is_open(&self) -> bool {
501        true
502    }
503}
504
505impl fmt::Debug for WebSocketConnection {
506    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507        f.debug_struct("WebSocketConnection")
508            .finish_non_exhaustive()
509    }
510}
511
512/// Concrete WebSocket client implementation using tokio-tungstenite-wasm
513pub mod tungstenite_client {
514    use super::*;
515    use crate::IntoStatic;
516    use futures::{SinkExt, StreamExt};
517
518    /// WebSocket client backed by tokio-tungstenite-wasm
519    #[derive(Debug, Clone, Default)]
520    pub struct TungsteniteClient;
521
522    impl TungsteniteClient {
523        /// Create a new tungstenite WebSocket client
524        pub fn new() -> Self {
525            Self
526        }
527    }
528
529    impl WebSocketClient for TungsteniteClient {
530        type Error = tokio_tungstenite_wasm::Error;
531
532        async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> {
533            let ws_stream = tokio_tungstenite_wasm::connect(url.as_str()).await?;
534
535            let (sink, stream) = ws_stream.split();
536
537            // Convert tungstenite messages to our WsMessage
538            let rx_stream = stream.filter_map(|result| async move {
539                match result {
540                    Ok(msg) => match convert_message(msg) {
541                        Some(ws_msg) => Some(Ok(ws_msg)),
542                        None => None, // Skip ping/pong
543                    },
544                    Err(e) => Some(Err(StreamError::transport(e))),
545                }
546            });
547
548            let rx = WsStream::new(rx_stream);
549
550            // Convert our WsMessage to tungstenite messages
551            let tx_sink = sink.with(|msg: WsMessage| async move {
552                Ok::<_, tokio_tungstenite_wasm::Error>(msg.into())
553            });
554
555            let tx_sink_mapped = tx_sink.sink_map_err(|e| StreamError::transport(e));
556            let tx = WsSink::new(tx_sink_mapped);
557
558            Ok(WebSocketConnection::new(tx, rx))
559        }
560    }
561
562    /// Convert tokio-tungstenite-wasm Message to our WsMessage
563    /// Returns None for Ping/Pong which we auto-handle
564    fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> {
565        use tokio_tungstenite_wasm::Message;
566
567        match msg {
568            Message::Text(vec) => {
569                // tokio-tungstenite-wasm Text contains Vec<u8> (UTF-8 validated)
570                let bytes = Bytes::from(vec);
571                Some(WsMessage::Text(unsafe {
572                    WsText::from_bytes_unchecked(bytes)
573                }))
574            }
575            Message::Binary(vec) => Some(WsMessage::Binary(Bytes::from(vec))),
576            Message::Close(frame) => {
577                let close_frame = frame.map(|f| {
578                    let code = convert_close_code(f.code);
579                    CloseFrame::new(code, CowStr::from(f.reason.into_owned()))
580                });
581                Some(WsMessage::Close(close_frame))
582            }
583        }
584    }
585
586    /// Convert tokio-tungstenite-wasm CloseCode to our CloseCode
587    fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode {
588        use tokio_tungstenite_wasm::CloseCode as TungsteniteCode;
589
590        match code {
591            TungsteniteCode::Normal => CloseCode::Normal,
592            TungsteniteCode::Away => CloseCode::Away,
593            TungsteniteCode::Protocol => CloseCode::Protocol,
594            TungsteniteCode::Unsupported => CloseCode::Unsupported,
595            TungsteniteCode::Invalid => CloseCode::Invalid,
596            TungsteniteCode::Policy => CloseCode::Policy,
597            TungsteniteCode::Size => CloseCode::Size,
598            TungsteniteCode::Extension => CloseCode::Extension,
599            TungsteniteCode::Error => CloseCode::Error,
600            TungsteniteCode::Tls => CloseCode::Tls,
601            // For other variants, extract raw code
602            other => {
603                let raw: u16 = other.into();
604                CloseCode::from(raw)
605            }
606        }
607    }
608
609    impl From<WsMessage> for tokio_tungstenite_wasm::Message {
610        fn from(msg: WsMessage) -> Self {
611            use tokio_tungstenite_wasm::Message;
612
613            match msg {
614                WsMessage::Text(text) => {
615                    // tokio-tungstenite-wasm Text expects String
616                    let bytes = text.into_bytes();
617                    // Safe: WsText is already UTF-8 validated
618                    let string = unsafe { String::from_utf8_unchecked(bytes.to_vec()) };
619                    Message::Text(string)
620                }
621                WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()),
622                WsMessage::Close(frame) => {
623                    let close_frame = frame.map(|f| {
624                        let code = u16::from(f.code).into();
625                        tokio_tungstenite_wasm::CloseFrame {
626                            code,
627                            reason: f.reason.into_static().to_string().into(),
628                        }
629                    });
630                    Message::Close(close_frame)
631                }
632            }
633        }
634    }
635}
636
#[cfg(test)]
mod tests {
    use super::*;

    /// A text value built from a `&str` reads back unchanged.
    #[test]
    fn ws_text_from_string() {
        let text: WsText = "hello".into();
        assert_eq!(text.as_str(), "hello");
    }

    /// `WsText` dereferences to its string contents.
    #[test]
    fn ws_text_deref() {
        let text = WsText::from("world".to_owned());
        let view: &str = &text;
        assert_eq!(view, "world");
    }

    /// Valid UTF-8 bytes are accepted by `TryFrom<Bytes>`.
    #[test]
    fn ws_text_try_from_bytes() {
        let text = WsText::try_from(Bytes::from_static(b"test")).expect("ASCII is valid UTF-8");
        assert_eq!(text.as_str(), "test");
    }

    /// Invalid UTF-8 bytes are rejected by `TryFrom<Bytes>`.
    #[test]
    fn ws_text_invalid_utf8() {
        let result = WsText::try_from(Bytes::from(vec![0xFF_u8, 0xFE]));
        assert!(result.is_err());
    }

    /// A `&str` becomes a text message.
    #[test]
    fn ws_message_text() {
        let msg: WsMessage = "hello".into();
        assert!(msg.is_text());
        assert_eq!(msg.as_text(), Some("hello"));
    }

    /// A `Vec<u8>` becomes a binary message whose bytes are exposed by `as_bytes`.
    #[test]
    fn ws_message_binary() {
        let msg = WsMessage::from(vec![1u8, 2, 3]);
        assert!(msg.is_binary());
        assert_eq!(msg.as_bytes(), Some([1u8, 2, 3].as_slice()));
    }

    /// Named and unnamed close codes round-trip through `u16`.
    #[test]
    fn close_code_conversion() {
        assert_eq!(u16::from(CloseCode::Normal), 1000);
        assert_eq!(CloseCode::from(1000u16), CloseCode::Normal);
        assert_eq!(CloseCode::from(9999u16), CloseCode::Other(9999));
    }

    /// A connection can be assembled from an arbitrary stream and sink.
    #[test]
    fn websocket_connection_has_tx_and_rx() {
        use futures::{sink::SinkExt, stream};

        let rx = WsStream::new(stream::iter(vec![Ok(WsMessage::from("test"))]));
        let tx = WsSink::new(
            futures::sink::drain()
                .sink_map_err(|_: std::convert::Infallible| StreamError::closed()),
        );

        assert!(WebSocketConnection::new(tx, rx).is_open());
    }
}