Skip to main content

jacquard_common/
websocket.rs

1//! WebSocket client abstraction
2
3use crate::CowStr;
4use crate::deps::fluent_uri::Uri;
5use crate::stream::StreamError;
6use alloc::boxed::Box;
7use alloc::string::String;
8use alloc::string::ToString;
9use alloc::vec::Vec;
10use bytes::Bytes;
11use core::borrow::Borrow;
12use core::fmt::{self, Display};
13use core::future::Future;
14use core::ops::Deref;
15use core::pin::Pin;
16use n0_future::Stream;
17
18/// UTF-8 validated bytes for WebSocket text messages
19#[repr(transparent)]
20#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
21pub struct WsText(Bytes);
22
23impl WsText {
24    /// Create from static string
25    pub const fn from_static(s: &'static str) -> Self {
26        Self(Bytes::from_static(s.as_bytes()))
27    }
28
29    /// Get as string slice
30    pub fn as_str(&self) -> &str {
31        unsafe { core::str::from_utf8_unchecked(&self.0) }
32    }
33
34    /// Create from bytes without validation (caller must ensure UTF-8)
35    ///
36    /// # Safety
37    /// Bytes must be valid UTF-8
38    pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self {
39        Self(bytes)
40    }
41
42    /// Convert into underlying bytes
43    pub fn into_bytes(self) -> Bytes {
44        self.0
45    }
46}
47
48impl Deref for WsText {
49    type Target = str;
50    fn deref(&self) -> &str {
51        self.as_str()
52    }
53}
54
55impl AsRef<str> for WsText {
56    fn as_ref(&self) -> &str {
57        self.as_str()
58    }
59}
60
61impl AsRef<[u8]> for WsText {
62    fn as_ref(&self) -> &[u8] {
63        &self.0
64    }
65}
66
67impl AsRef<Bytes> for WsText {
68    fn as_ref(&self) -> &Bytes {
69        &self.0
70    }
71}
72
73impl Borrow<str> for WsText {
74    fn borrow(&self) -> &str {
75        self.as_str()
76    }
77}
78
79impl Display for WsText {
80    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81        Display::fmt(self.as_str(), f)
82    }
83}
84
85impl From<String> for WsText {
86    fn from(s: String) -> Self {
87        Self(Bytes::from(s))
88    }
89}
90
91impl From<&str> for WsText {
92    fn from(s: &str) -> Self {
93        Self(Bytes::copy_from_slice(s.as_bytes()))
94    }
95}
96
97impl From<&String> for WsText {
98    fn from(s: &String) -> Self {
99        Self::from(s.as_str())
100    }
101}
102
103impl TryFrom<Bytes> for WsText {
104    type Error = core::str::Utf8Error;
105    fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
106        core::str::from_utf8(&bytes)?;
107        Ok(Self(bytes))
108    }
109}
110
111impl TryFrom<Vec<u8>> for WsText {
112    type Error = core::str::Utf8Error;
113    fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> {
114        Self::try_from(Bytes::from(vec))
115    }
116}
117
118impl From<WsText> for Bytes {
119    fn from(t: WsText) -> Bytes {
120        t.0
121    }
122}
123
124impl Default for WsText {
125    fn default() -> Self {
126        Self(Bytes::new())
127    }
128}
129
/// WebSocket close status code.
///
/// The registered codes get dedicated variants; any other value (including codes
/// without a variant here, such as 1005/1006/1012) is carried verbatim in
/// [`CloseCode::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum CloseCode {
    /// 1000: normal closure
    Normal = 1000,
    /// 1001: endpoint going away
    Away = 1001,
    /// 1002: protocol error
    Protocol = 1002,
    /// 1003: unsupported data
    Unsupported = 1003,
    /// 1007: invalid frame payload data
    Invalid = 1007,
    /// 1008: policy violation
    Policy = 1008,
    /// 1009: message too big
    Size = 1009,
    /// 1010: extension negotiation failure
    Extension = 1010,
    /// 1011: unexpected condition
    Error = 1011,
    /// 1015: TLS handshake failure
    Tls = 1015,
    /// Any code without a dedicated variant
    Other(u16),
}

impl From<u16> for CloseCode {
    /// Map a raw wire value onto its named variant, falling back to `Other`.
    fn from(raw: u16) -> Self {
        match raw {
            1000 => Self::Normal,
            1001 => Self::Away,
            1002 => Self::Protocol,
            1003 => Self::Unsupported,
            1007 => Self::Invalid,
            1008 => Self::Policy,
            1009 => Self::Size,
            1010 => Self::Extension,
            1011 => Self::Error,
            1015 => Self::Tls,
            unknown => Self::Other(unknown),
        }
    }
}

impl From<CloseCode> for u16 {
    /// Recover the raw wire value of a close code.
    fn from(close: CloseCode) -> u16 {
        use CloseCode as C;

        match close {
            C::Other(raw) => raw,
            C::Normal => 1000,
            C::Away => 1001,
            C::Protocol => 1002,
            C::Unsupported => 1003,
            C::Invalid => 1007,
            C::Policy => 1008,
            C::Size => 1009,
            C::Extension => 1010,
            C::Error => 1011,
            C::Tls => 1015,
        }
    }
}
193
194/// WebSocket close frame
195#[derive(Debug, Clone, PartialEq, Eq)]
196pub struct CloseFrame<'a> {
197    /// Close code
198    pub code: CloseCode,
199    /// Close reason text
200    pub reason: CowStr<'a>,
201}
202
203impl<'a> CloseFrame<'a> {
204    /// Create a new close frame
205    pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self {
206        Self {
207            code,
208            reason: reason.into(),
209        }
210    }
211}
212
213/// WebSocket message
214#[derive(Debug, Clone, PartialEq, Eq)]
215pub enum WsMessage {
216    /// Text message (UTF-8)
217    Text(WsText),
218    /// Binary message
219    Binary(Bytes),
220    /// Close frame
221    Close(Option<CloseFrame<'static>>),
222}
223
224impl WsMessage {
225    /// Check if this is a text message
226    pub fn is_text(&self) -> bool {
227        matches!(self, WsMessage::Text(_))
228    }
229
230    /// Check if this is a binary message
231    pub fn is_binary(&self) -> bool {
232        matches!(self, WsMessage::Binary(_))
233    }
234
235    /// Check if this is a close message
236    pub fn is_close(&self) -> bool {
237        matches!(self, WsMessage::Close(_))
238    }
239
240    /// Get as text, if this is a text message
241    pub fn as_text(&self) -> Option<&str> {
242        match self {
243            WsMessage::Text(t) => Some(t.as_str()),
244            _ => None,
245        }
246    }
247
248    /// Get as bytes
249    pub fn as_bytes(&self) -> Option<&[u8]> {
250        match self {
251            WsMessage::Text(t) => Some(t.as_ref()),
252            WsMessage::Binary(b) => Some(b),
253            WsMessage::Close(_) => None,
254        }
255    }
256}
257
258impl From<WsText> for WsMessage {
259    fn from(text: WsText) -> Self {
260        WsMessage::Text(text)
261    }
262}
263
264impl From<String> for WsMessage {
265    fn from(s: String) -> Self {
266        WsMessage::Text(WsText::from(s))
267    }
268}
269
270impl From<&str> for WsMessage {
271    fn from(s: &str) -> Self {
272        WsMessage::Text(WsText::from(s))
273    }
274}
275
276impl From<Bytes> for WsMessage {
277    fn from(bytes: Bytes) -> Self {
278        WsMessage::Binary(bytes)
279    }
280}
281
282impl From<Vec<u8>> for WsMessage {
283    fn from(vec: Vec<u8>) -> Self {
284        WsMessage::Binary(Bytes::from(vec))
285    }
286}
287
288/// WebSocket message stream
289#[cfg(not(target_arch = "wasm32"))]
290pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>>);
291
292/// WebSocket message stream
293#[cfg(target_arch = "wasm32")]
294pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>>);
295
296impl WsStream {
297    /// Create a new message stream
298    #[cfg(not(target_arch = "wasm32"))]
299    pub fn new<S>(stream: S) -> Self
300    where
301        S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static,
302    {
303        Self(Box::pin(stream))
304    }
305
306    /// Create a new message stream
307    #[cfg(target_arch = "wasm32")]
308    pub fn new<S>(stream: S) -> Self
309    where
310        S: Stream<Item = Result<WsMessage, StreamError>> + 'static,
311    {
312        Self(Box::pin(stream))
313    }
314
315    /// Convert into the inner pinned boxed stream
316    #[cfg(not(target_arch = "wasm32"))]
317    pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>> {
318        self.0
319    }
320
321    /// Convert into the inner pinned boxed stream
322    #[cfg(target_arch = "wasm32")]
323    pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>> {
324        self.0
325    }
326
327    /// Split this stream into two streams that both receive all messages
328    ///
329    /// Messages are cloned (cheaply via Bytes rc). Spawns a forwarder task.
330    /// Both returned streams will receive all messages from the original stream.
331    /// The forwarder continues as long as at least one stream is alive.
332    /// If the underlying stream errors, both teed streams will end.
333    pub fn tee(self) -> (WsStream, WsStream) {
334        use futures::channel::mpsc;
335        use n0_future::StreamExt as _;
336
337        let (tx1, rx1) = mpsc::unbounded();
338        let (tx2, rx2) = mpsc::unbounded();
339
340        n0_future::task::spawn(async move {
341            let mut stream = self.0;
342            while let Some(result) = stream.next().await {
343                match result {
344                    Ok(msg) => {
345                        // Clone message (cheap - Bytes is rc'd)
346                        let msg2 = msg.clone();
347
348                        // Send to both channels, continue if at least one succeeds
349                        let send1 = tx1.unbounded_send(Ok(msg));
350                        let send2 = tx2.unbounded_send(Ok(msg2));
351
352                        // Only stop if both channels are closed
353                        if send1.is_err() && send2.is_err() {
354                            break;
355                        }
356                    }
357                    Err(_e) => {
358                        // Underlying stream errored, stop forwarding.
359                        // Both channels will close, ending both streams.
360                        break;
361                    }
362                }
363            }
364        });
365
366        (WsStream::new(rx1), WsStream::new(rx2))
367    }
368}
369
370impl fmt::Debug for WsStream {
371    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372        f.debug_struct("WsStream").finish_non_exhaustive()
373    }
374}
375
376/// WebSocket message sink
377#[cfg(not(target_arch = "wasm32"))]
378pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>>);
379
380/// WebSocket message sink
381#[cfg(target_arch = "wasm32")]
382pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>>);
383
384impl WsSink {
385    /// Create a new message sink
386    #[cfg(not(target_arch = "wasm32"))]
387    pub fn new<S>(sink: S) -> Self
388    where
389        S: n0_future::Sink<WsMessage, Error = StreamError> + Send + 'static,
390    {
391        Self(Box::pin(sink))
392    }
393
394    /// Create a new message sink
395    #[cfg(target_arch = "wasm32")]
396    pub fn new<S>(sink: S) -> Self
397    where
398        S: n0_future::Sink<WsMessage, Error = StreamError> + 'static,
399    {
400        Self(Box::pin(sink))
401    }
402
403    /// Convert into the inner boxed sink
404    #[cfg(not(target_arch = "wasm32"))]
405    pub fn into_inner(
406        self,
407    ) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
408        self.0
409    }
410
411    /// Convert into the inner boxed sink
412    #[cfg(target_arch = "wasm32")]
413    pub fn into_inner(self) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>> {
414        self.0
415    }
416
417    /// get a mutable reference to the inner boxed sink
418    #[cfg(not(target_arch = "wasm32"))]
419    pub fn get_mut(
420        &mut self,
421    ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
422        use core::borrow::BorrowMut;
423
424        self.0.borrow_mut()
425    }
426
427    /// get a mutable reference to the inner boxed sink
428    #[cfg(target_arch = "wasm32")]
429    pub fn get_mut(
430        &mut self,
431    ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + 'static>> {
432        use core::borrow::BorrowMut;
433
434        self.0.borrow_mut()
435    }
436}
437
438impl fmt::Debug for WsSink {
439    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440        f.debug_struct("WsSink").finish_non_exhaustive()
441    }
442}
443
444/// WebSocket client trait
445#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
446pub trait WebSocketClient: Sync {
447    /// Error type for WebSocket operations
448    type Error: core::error::Error + Send + Sync + 'static;
449
450    /// Connect to a WebSocket endpoint
451    fn connect(
452        &self,
453        uri: Uri<&str>,
454    ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>;
455
456    /// Connect to a WebSocket endpoint with custom headers
457    ///
458    /// Default implementation ignores headers and calls `connect()`.
459    /// Override this method to support authentication headers for subscriptions.
460    fn connect_with_headers(
461        &self,
462        uri: Uri<&str>,
463        _headers: Vec<(CowStr<'_>, CowStr<'_>)>,
464    ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> {
465        async move { self.connect(uri).await }
466    }
467}
468
469/// WebSocket connection with bidirectional streams
470pub struct WebSocketConnection {
471    tx: WsSink,
472    rx: WsStream,
473}
474
475impl WebSocketConnection {
476    /// Create a new WebSocket connection
477    pub fn new(tx: WsSink, rx: WsStream) -> Self {
478        Self { tx, rx }
479    }
480
481    /// Get mutable access to the sender
482    pub fn sender_mut(&mut self) -> &mut WsSink {
483        &mut self.tx
484    }
485
486    /// Get mutable access to the receiver
487    pub fn receiver_mut(&mut self) -> &mut WsStream {
488        &mut self.rx
489    }
490
491    /// Get a reference to the receiver
492    pub fn receiver(&self) -> &WsStream {
493        &self.rx
494    }
495
496    /// Get a reference to the sender
497    pub fn sender(&self) -> &WsSink {
498        &self.tx
499    }
500
501    /// Split into sender and receiver
502    pub fn split(self) -> (WsSink, WsStream) {
503        (self.tx, self.rx)
504    }
505
506    /// Check if connection is open (always true for this abstraction)
507    pub fn is_open(&self) -> bool {
508        true
509    }
510}
511
512impl fmt::Debug for WebSocketConnection {
513    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
514        f.debug_struct("WebSocketConnection")
515            .finish_non_exhaustive()
516    }
517}
518
519/// Concrete WebSocket client implementation using tokio-tungstenite-wasm
520pub mod tungstenite_client {
521    use super::*;
522    use crate::IntoStatic;
523    use futures::{SinkExt, StreamExt};
524
525    /// WebSocket client backed by tokio-tungstenite-wasm
526    #[derive(Debug, Clone, Default)]
527    pub struct TungsteniteClient;
528
529    impl TungsteniteClient {
530        /// Create a new tungstenite WebSocket client
531        pub fn new() -> Self {
532            Self
533        }
534    }
535
536    impl WebSocketClient for TungsteniteClient {
537        type Error = tokio_tungstenite_wasm::Error;
538
539        async fn connect(&self, uri: Uri<&str>) -> Result<WebSocketConnection, Self::Error> {
540            let ws_stream = tokio_tungstenite_wasm::connect(uri.as_str()).await?;
541
542            let (sink, stream) = ws_stream.split();
543
544            // Convert tungstenite messages to our WsMessage
545            let rx_stream = stream.filter_map(|result| async move {
546                match result {
547                    Ok(msg) => match convert_message(msg) {
548                        Some(ws_msg) => Some(Ok(ws_msg)),
549                        None => None, // Skip ping/pong
550                    },
551                    Err(e) => Some(Err(StreamError::transport(e))),
552                }
553            });
554
555            let rx = WsStream::new(rx_stream);
556
557            // Convert our WsMessage to tungstenite messages
558            let tx_sink = sink.with(|msg: WsMessage| async move {
559                Ok::<_, tokio_tungstenite_wasm::Error>(msg.into())
560            });
561
562            let tx_sink_mapped = tx_sink.sink_map_err(|e| StreamError::transport(e));
563            let tx = WsSink::new(tx_sink_mapped);
564
565            Ok(WebSocketConnection::new(tx, rx))
566        }
567    }
568
569    /// Convert tokio-tungstenite-wasm Message to our WsMessage
570    /// Returns None for Ping/Pong which we auto-handle
571    fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> {
572        use tokio_tungstenite_wasm::Message;
573
574        match msg {
575            Message::Text(vec) => {
576                // tokio-tungstenite-wasm Text contains Vec<u8> (UTF-8 validated)
577                let bytes = Bytes::from(vec);
578                Some(WsMessage::Text(unsafe {
579                    WsText::from_bytes_unchecked(bytes)
580                }))
581            }
582            Message::Binary(vec) => Some(WsMessage::Binary(Bytes::from(vec))),
583            Message::Close(frame) => {
584                let close_frame = frame.map(|f| {
585                    let code = convert_close_code(f.code);
586                    CloseFrame::new(code, CowStr::from(f.reason.into_owned()))
587                });
588                Some(WsMessage::Close(close_frame))
589            }
590        }
591    }
592
593    /// Convert tokio-tungstenite-wasm CloseCode to our CloseCode
594    fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode {
595        use tokio_tungstenite_wasm::CloseCode as TungsteniteCode;
596
597        match code {
598            TungsteniteCode::Normal => CloseCode::Normal,
599            TungsteniteCode::Away => CloseCode::Away,
600            TungsteniteCode::Protocol => CloseCode::Protocol,
601            TungsteniteCode::Unsupported => CloseCode::Unsupported,
602            TungsteniteCode::Invalid => CloseCode::Invalid,
603            TungsteniteCode::Policy => CloseCode::Policy,
604            TungsteniteCode::Size => CloseCode::Size,
605            TungsteniteCode::Extension => CloseCode::Extension,
606            TungsteniteCode::Error => CloseCode::Error,
607            TungsteniteCode::Tls => CloseCode::Tls,
608            // For other variants, extract raw code
609            other => {
610                let raw: u16 = other.into();
611                CloseCode::from(raw)
612            }
613        }
614    }
615
616    impl From<WsMessage> for tokio_tungstenite_wasm::Message {
617        fn from(msg: WsMessage) -> Self {
618            use tokio_tungstenite_wasm::Message;
619
620            match msg {
621                WsMessage::Text(text) => {
622                    // tokio-tungstenite-wasm Text expects String
623                    let bytes = text.into_bytes();
624                    // Safe: WsText is already UTF-8 validated
625                    let string = unsafe { String::from_utf8_unchecked(bytes.to_vec()) };
626                    Message::Text(string)
627                }
628                WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()),
629                WsMessage::Close(frame) => {
630                    let close_frame = frame.map(|f| {
631                        let code = u16::from(f.code).into();
632                        tokio_tungstenite_wasm::CloseFrame {
633                            code,
634                            reason: f.reason.into_static().to_string().into(),
635                        }
636                    });
637                    Message::Close(close_frame)
638                }
639            }
640        }
641    }
642}
643
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::collections::BTreeSet;

    #[test]
    fn ws_text_from_string() {
        let text = WsText::from("hello");
        assert_eq!(text.as_str(), "hello");
    }

    #[test]
    fn ws_text_deref() {
        let text = WsText::from(String::from("world"));
        assert_eq!(&*text, "world");
    }

    #[test]
    fn ws_text_try_from_bytes() {
        let bytes = Bytes::from("test");
        let text = WsText::try_from(bytes).unwrap();
        assert_eq!(text.as_str(), "test");
    }

    #[test]
    fn ws_text_invalid_utf8() {
        let bytes = Bytes::from(vec![0xFF, 0xFE]);
        assert!(WsText::try_from(bytes).is_err());
    }

    #[test]
    fn ws_text_try_from_vec() {
        assert_eq!(WsText::try_from(b"abc".to_vec()).unwrap().as_str(), "abc");
        // A lone UTF-8 lead byte is rejected.
        assert!(WsText::try_from(vec![0xC3u8]).is_err());
    }

    #[test]
    fn ws_text_from_static_and_default() {
        let hello = WsText::from_static("hello");
        assert_eq!(hello.as_str(), "hello");
        assert_eq!(WsText::default().as_str(), "");
    }

    #[test]
    fn ws_text_borrow_str_lookup() {
        let set: BTreeSet<WsText> = ["a", "b"].iter().copied().map(WsText::from).collect();
        assert!(set.contains("a"));
        assert!(!set.contains("c"));
    }

    #[test]
    fn ws_message_text() {
        let msg = WsMessage::from("hello");
        assert!(msg.is_text());
        assert_eq!(msg.as_text(), Some("hello"));
    }

    #[test]
    fn ws_message_binary() {
        let msg = WsMessage::from(vec![1, 2, 3]);
        assert!(msg.is_binary());
        assert_eq!(msg.as_bytes(), Some(&[1u8, 2, 3][..]));
    }

    #[test]
    fn ws_message_from_bytes_is_binary() {
        let msg = WsMessage::from(Bytes::from_static(b"\x01\x02"));
        assert!(msg.is_binary());
        assert!(!msg.is_text());
        assert_eq!(msg.as_text(), None);
        assert_eq!(msg.as_bytes(), Some(&[1u8, 2][..]));
    }

    #[test]
    fn ws_message_close_has_no_payload() {
        let frame = CloseFrame::new(CloseCode::Normal, String::from("bye"));
        let msg = WsMessage::Close(Some(frame));
        assert!(msg.is_close());
        assert!(!msg.is_text());
        assert!(!msg.is_binary());
        assert_eq!(msg.as_text(), None);
        assert_eq!(msg.as_bytes(), None);
    }

    #[test]
    fn close_code_conversion() {
        assert_eq!(u16::from(CloseCode::Normal), 1000);
        assert_eq!(CloseCode::from(1000), CloseCode::Normal);
        assert_eq!(CloseCode::from(9999), CloseCode::Other(9999));
    }

    #[test]
    fn close_code_round_trips_through_u16() {
        for raw in [1000u16, 1001, 1002, 1003, 1005, 1007, 1008, 1009, 1010, 1011, 1015, 4000] {
            assert_eq!(u16::from(CloseCode::from(raw)), raw);
        }
        // Codes without a dedicated variant are preserved verbatim.
        assert_eq!(CloseCode::from(1005), CloseCode::Other(1005));
    }

    #[test]
    fn websocket_connection_has_tx_and_rx() {
        use futures::sink::SinkExt;
        use futures::stream;

        let rx_stream = stream::iter(vec![Ok(WsMessage::from("test"))]);
        let rx = WsStream::new(rx_stream);

        let drain_sink = futures::sink::drain()
            .sink_map_err(|_: std::convert::Infallible| StreamError::closed());
        let tx = WsSink::new(drain_sink);

        let conn = WebSocketConnection::new(tx, rx);
        assert!(conn.is_open());
    }
}