Skip to main content

dioxus_fullstack/payloads/
websocket.rs

1#![allow(unreachable_code)]
2#![allow(unused_imports)]
3
4//! This module implements WebSocket support for Dioxus Fullstack applications.
5//!
6//! WebSockets provide a full-duplex communication channel over a single, long-lived connection.
7//!
8//! This makes them ideal for real-time applications where the server and the client need to communicate
9//! frequently and with low latency. Unlike Server-Sent Events (SSE), WebSockets allow the direct
10//! transport of binary data, enabling things like video and audio streaming as well as more efficient
11//! zero-copy serialization formats.
12//!
13//! This module implements a variety of types:
14//! - `Websocket<In, Out, E>`: Represents a WebSocket connection that can send messages of type `In` and receive messages of type `Out`, using the encoding `E`.
15//! - `UseWebsocket<In, Out, E>`: A hook that provides a reactive interface to a WebSocket connection.
16//! - `WebSocketOptions`: Configuration options for establishing a WebSocket connection.
17//! - `TypedWebsocket<In, Out, E>`: A typed wrapper around an Axum WebSocket connection for server-side use.
18//! - `WebsocketState`: An enum representing the state of the WebSocket connection.
19//! - plus a variety of error types and traits for encoding/decoding messages.
20//!
21//! Dioxus Fullstack websockets are typed in both directions, letting the happy path (`.send()` and `.recv()`)
22//! automatically serialize and deserialize messages for you.
23
24use crate::{ClientRequest, Encoding, FromResponse, IntoRequest, JsonEncoding, ServerFnError};
25use axum::{
26    extract::{FromRequest, Request},
27    http::StatusCode,
28};
29use axum_core::response::{IntoResponse, Response};
30use bytes::Bytes;
31use dioxus_core::{use_hook, CapturedError, Result};
32use dioxus_fullstack_core::{HttpError, RequestError};
33use dioxus_hooks::{use_resource, Resource, UseWaker};
34use dioxus_hooks::{use_signal, use_waker};
35use dioxus_signals::{ReadSignal, ReadableExt, ReadableOptionExt, Signal, WritableExt};
36use futures::{
37    stream::{SplitSink, SplitStream},
38    Sink, SinkExt, Stream, StreamExt, TryFutureExt,
39};
40use serde::{de::DeserializeOwned, Serialize};
41use std::{
42    marker::PhantomData,
43    pin::Pin,
44    prelude::rust_2024::Future,
45    rc::Rc,
46    task::{ready, Context, Poll},
47};
48
49#[cfg(feature = "web")]
50use {
51    futures_util::lock::Mutex,
52    gloo_net::websocket::{futures::WebSocket as WsWebsocket, Message as WsMessage},
53};
54
55/// A hook that provides a reactive interface to a WebSocket connection.
56///
57/// WebSockets provide a full-duplex communication channel over a single, long-lived connection.
58///
59/// This makes them ideal for real-time applications where the server and the client need to communicate
60/// frequently and with low latency. Unlike Server-Sent Events (SSE), WebSockets allow the direct
61/// transport of binary data, enabling things like video and audio streaming as well as more efficient
62/// zero-copy serialization formats.
63///
64/// This hook takes a function that returns a future which resolves to a `Websocket<In, Out, E>` -
65/// usually a server function.
66pub fn use_websocket<
67    In: 'static,
68    Out: 'static,
69    E: Into<CapturedError> + 'static,
70    F: Future<Output = Result<Websocket<In, Out, Enc>, E>> + 'static,
71    Enc: Encoding,
72>(
73    mut connect_to_websocket: impl FnMut() -> F + 'static,
74) -> UseWebsocket<In, Out, Enc> {
75    let mut waker = use_waker();
76    let mut status = use_signal(|| WebsocketState::Connecting);
77    let status_read = use_hook(|| ReadSignal::new(status));
78
79    let connection = use_resource(move || {
80        let connection = connect_to_websocket().map_err(|e| e.into());
81        async move {
82            let connection = connection.await;
83
84            // Update the status based on the result of the connection attempt
85            match connection.as_ref() {
86                Ok(_) => status.set(WebsocketState::Open),
87                Err(_) => status.set(WebsocketState::FailedToConnect),
88            }
89
90            // Wake up the `.recv()` calls waiting for the connection to be established
91            waker.wake(());
92
93            // Wrap in Rc so we can clone it out of the Resource without holding
94            // a borrow guard across await points
95            connection.map(Rc::new)
96        }
97    });
98
99    UseWebsocket {
100        connection,
101        waker,
102        status,
103        status_read,
104    }
105}
106
107/// The return type of the `use_websocket` hook.
108///
109/// See the `use_websocket` documentation for more details.
110///
111/// This handle provides methods to send and receive messages, check the connection status,
112/// and wait for the connection to be established.
113pub struct UseWebsocket<In, Out, Enc = JsonEncoding>
114where
115    In: 'static,
116    Out: 'static,
117    Enc: 'static,
118{
119    #[allow(clippy::type_complexity)]
120    connection: Resource<Result<Rc<Websocket<In, Out, Enc>>, CapturedError>>,
121    waker: UseWaker<()>,
122    status: Signal<WebsocketState>,
123    status_read: ReadSignal<WebsocketState>,
124}
125
126impl<In, Out, E> UseWebsocket<In, Out, E> {
127    /// Wait for the connection to be established. This guarantees that subsequent calls to methods like
128    /// `.try_recv()` will not fail due to the connection not being ready.
129    pub async fn connect(&self) -> WebsocketState {
130        // Wait for the connection to be established
131        while !self.connection.finished() {
132            _ = self.waker.wait().await;
133        }
134
135        self.status.cloned()
136    }
137
138    /// Returns true if the WebSocket is currently connecting.
139    ///
140    /// This can be useful to present a loading state to the user while the connection is being established.
141    pub fn connecting(&self) -> bool {
142        matches!(self.status.cloned(), WebsocketState::Connecting)
143    }
144
145    /// Returns true if the Websocket is closed due to an error.
146    pub fn is_err(&self) -> bool {
147        matches!(self.status.cloned(), WebsocketState::FailedToConnect)
148    }
149
150    /// Returns true if the WebSocket is currently shut down and cannot be used to send or receive messages.
151    pub fn is_closed(&self) -> bool {
152        matches!(
153            self.status.cloned(),
154            WebsocketState::Closed | WebsocketState::FailedToConnect
155        )
156    }
157
158    /// Get the current status of the WebSocket connection.
159    pub fn status(&self) -> ReadSignal<WebsocketState> {
160        self.status_read
161    }
162
163    /// Send a raw message over the WebSocket connection
164    ///
165    /// To send a message with a particular type, see the `.send()` method instead.
166    pub async fn send_raw(&self, msg: Message) -> Result<(), WebsocketError> {
167        self.connect().await;
168        self.get_connection()?.send_raw(msg).await
169    }
170
171    /// Receive a raw message from the WebSocket connection
172    ///
173    /// To receive a message with a particular type, see the `.recv()` method instead.
174    pub async fn recv_raw(&mut self) -> Result<Message, WebsocketError> {
175        self.connect().await;
176        let ws = self.get_connection()?;
177
178        // Race the recv against the waker — if the connection is being recreated
179        // (e.g. a reactive dependency changed), the waker fires and we return an error
180        // so the caller's loop can restart and pick up the new connection.
181        let recv_fut = ws.recv_raw();
182        let waker_fut = self.waker.wait();
183        futures::pin_mut!(recv_fut, waker_fut);
184
185        match futures::future::select(recv_fut, waker_fut).await {
186            futures::future::Either::Left((recv_result, _)) => {
187                if let Err(WebsocketError::ConnectionClosed { .. }) = recv_result.as_ref() {
188                    self.received_shutdown();
189                }
190                recv_result
191            }
192            futures::future::Either::Right(_) => Err(WebsocketError::ConnectionClosed {
193                code: CloseCode::Away,
194                description: "Connection replaced by a new one".to_string(),
195            }),
196        }
197    }
198
199    pub async fn send(&self, msg: In) -> Result<(), WebsocketError>
200    where
201        In: Serialize,
202        E: Encoding,
203    {
204        self.send_raw(Message::Binary(
205            E::to_bytes(&msg).ok_or_else(WebsocketError::serialization)?,
206        ))
207        .await
208    }
209
210    /// Receive the next message from the WebSocket connection, deserialized into the `Out` type.
211    ///
212    /// If the connection is still opening, this will wait until the connection is established.
213    /// If the connection fails to open or is killed while waiting, an error will be returned.
214    ///
215    /// This method returns an error if the connection is closed since we assume closed connections
216    /// are a "failure".
217    pub async fn recv(&mut self) -> Result<Out, WebsocketError>
218    where
219        Out: DeserializeOwned,
220        E: Encoding,
221    {
222        self.connect().await;
223        let ws = self.get_connection()?;
224
225        let recv_fut = ws.recv();
226        let waker_fut = self.waker.wait();
227        futures::pin_mut!(recv_fut, waker_fut);
228
229        match futures::future::select(recv_fut, waker_fut).await {
230            futures::future::Either::Left((recv_result, _)) => {
231                if let Err(WebsocketError::ConnectionClosed { .. }) = recv_result.as_ref() {
232                    self.received_shutdown();
233                }
234                recv_result
235            }
236            futures::future::Either::Right(_) => Err(WebsocketError::ConnectionClosed {
237                code: CloseCode::Away,
238                description: "Connection replaced by a new one".to_string(),
239            }),
240        }
241    }
242
243    /// Set the WebSocket connection.
244    ///
245    /// This method takes a `Result<Websocket<In, Out, E>, Err>`, allowing you to drive the connection
246    /// into an errored state manually.
247    pub fn set<Err: Into<CapturedError>>(&mut self, socket: Result<Websocket<In, Out, E>, Err>) {
248        match socket {
249            Ok(_) => self.status.set(WebsocketState::Open),
250            Err(_) => self.status.set(WebsocketState::FailedToConnect),
251        }
252
253        self.connection
254            .set(Some(socket.map(Rc::new).map_err(|e| e.into())));
255        self.waker.wake(());
256    }
257
258    /// Mark the WebSocket as closed. This is called internally when the connection is closed.
259    fn received_shutdown(&self) {
260        let mut _self = *self;
261        _self.status.set(WebsocketState::Closed);
262        _self.waker.wake(());
263    }
264
265    /// Clone the `Rc<Websocket>` out of the Resource using peek, so we don't hold a borrow
266    /// guard across await points. This prevents AlreadyBorrowed panics when the Resource
267    /// tries to write while recv() is awaiting.
268    #[allow(clippy::result_large_err)]
269    fn get_connection(&self) -> Result<Rc<Websocket<In, Out, E>>, WebsocketError> {
270        self.connection.with_peek(|opt| {
271            opt.as_ref()
272                .ok_or_else(WebsocketError::closed_away)?
273                .as_ref()
274                .map(Rc::clone)
275                .map_err(|_| WebsocketError::AlreadyClosed)
276        })
277    }
278}
279
280impl<In, Out, E> Copy for UseWebsocket<In, Out, E> {}
281impl<In, Out, E> Clone for UseWebsocket<In, Out, E> {
282    fn clone(&self) -> Self {
283        *self
284    }
285}
286
/// Lifecycle state of a WebSocket connection tracked by [`use_websocket`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum WebsocketState {
    /// The connection is being established.
    Connecting,

    /// The connection is open and can send and receive messages.
    Open,

    /// The connection is shutting down.
    Closing,

    /// The connection is closed and can no longer send or receive messages.
    Closed,

    /// The connection attempt failed.
    FailedToConnect,
}
304
305/// A WebSocket connection that can send and receive messages of type `In` and `Out`.
306pub struct Websocket<In = String, Out = String, E = JsonEncoding> {
307    protocol: Option<String>,
308
309    #[allow(clippy::type_complexity)]
310    _in: std::marker::PhantomData<fn() -> (In, Out, E)>,
311
312    #[cfg(not(target_arch = "wasm32"))]
313    native: Option<native::SplitSocket>,
314
315    #[cfg(feature = "web")]
316    web: Option<WebsysSocket>,
317
318    response: Option<axum::response::Response>,
319}
320
321impl<I, O, E> Websocket<I, O, E> {
322    pub async fn recv(&self) -> Result<O, WebsocketError>
323    where
324        O: DeserializeOwned,
325        E: Encoding,
326    {
327        loop {
328            let msg = self.recv_raw().await?;
329            match msg {
330                Message::Text(text) => {
331                    let e: O =
332                        E::decode(text.into()).ok_or_else(WebsocketError::deserialization)?;
333                    return Ok(e);
334                }
335                Message::Binary(bytes) => {
336                    let e: O = E::decode(bytes).ok_or_else(WebsocketError::deserialization)?;
337                    return Ok(e);
338                }
339                Message::Close { code, reason } => {
340                    return Err(WebsocketError::ConnectionClosed {
341                        code,
342                        description: reason,
343                    });
344                }
345
346                // todo - are we supposed to response to pings?
347                Message::Ping(_bytes) => continue,
348                Message::Pong(_bytes) => continue,
349            }
350        }
351    }
352
353    /// Send a typed message over the WebSocket connection.
354    ///
355    /// This method serializes the message using the specified encoding `E` before sending it.
356    /// The message will always be sent as a binary message, even if the encoding is valid UTF-8
357    /// like JSON.
358    pub async fn send(&self, msg: I) -> Result<(), WebsocketError>
359    where
360        I: Serialize,
361        E: Encoding,
362    {
363        let bytes = E::to_bytes(&msg).ok_or_else(WebsocketError::serialization)?;
364        self.send_raw(Message::Binary(bytes)).await
365    }
366
367    /// Send a raw message over the WebSocket connection.
368    ///
369    /// This method allows sending text, binary, ping, pong, and close messages directly.
370    pub async fn send_raw(&self, message: Message) -> Result<(), WebsocketError> {
371        #[cfg(feature = "web")]
372        if cfg!(target_arch = "wasm32") {
373            let mut sender = self
374                .web
375                .as_ref()
376                .ok_or_else(|| WebsocketError::Uninitialized)?
377                .sender
378                .lock()
379                .await;
380
381            match message {
382                Message::Text(s) => {
383                    sender.send(gloo_net::websocket::Message::Text(s)).await?;
384                }
385                Message::Binary(bytes) => {
386                    sender
387                        .send(gloo_net::websocket::Message::Bytes(bytes.into()))
388                        .await?;
389                }
390                Message::Close { .. } => {
391                    sender.close().await?;
392                }
393                Message::Ping(_bytes) => return Ok(()),
394                Message::Pong(_bytes) => return Ok(()),
395            }
396
397            return Ok(());
398        }
399
400        #[cfg(not(target_arch = "wasm32"))]
401        {
402            let mut sender = self
403                .native
404                .as_ref()
405                .ok_or_else(|| WebsocketError::Uninitialized)?
406                .sender
407                .lock()
408                .await;
409
410            sender
411                .send(message.into())
412                .await
413                .map_err(WebsocketError::from)?;
414        }
415
416        Ok(())
417    }
418
419    /// Receive a raw message from the WebSocket connection.
420    pub async fn recv_raw(&self) -> Result<Message, WebsocketError> {
421        #[cfg(feature = "web")]
422        if cfg!(target_arch = "wasm32") {
423            let mut conn = self.web.as_ref().unwrap().receiver.lock().await;
424            return match conn.next().await {
425                Some(Ok(WsMessage::Text(text))) => Ok(Message::Text(text)),
426                Some(Ok(WsMessage::Bytes(items))) => Ok(Message::Binary(items.into())),
427                Some(Err(e)) => Err(WebsocketError::from(e)),
428                None => Err(WebsocketError::closed_away()),
429            };
430        }
431
432        #[cfg(not(target_arch = "wasm32"))]
433        {
434            use tungstenite::Message as TMessage;
435            let mut conn = self.native.as_ref().unwrap().receiver.lock().await;
436            return match conn.next().await {
437                Some(Ok(res)) => match res {
438                    TMessage::Text(utf8_bytes) => Ok(Message::Text(utf8_bytes.to_string())),
439                    TMessage::Binary(bytes) => Ok(Message::Binary(bytes)),
440                    TMessage::Close(Some(cf)) => Ok(Message::Close {
441                        code: cf.code.into(),
442                        reason: cf.reason.to_string(),
443                    }),
444                    TMessage::Close(None) => Ok(Message::Close {
445                        code: CloseCode::Away,
446                        reason: "Away".to_string(),
447                    }),
448                    TMessage::Ping(bytes) => Ok(Message::Ping(bytes)),
449                    TMessage::Pong(bytes) => Ok(Message::Pong(bytes)),
450                    TMessage::Frame(_frame) => Err(WebsocketError::Unexpected),
451                },
452                Some(Err(e)) => Err(WebsocketError::from(e)),
453                None => Err(WebsocketError::closed_away()),
454            };
455        }
456
457        unimplemented!("Non web wasm32 clients are not supported yet")
458    }
459
460    pub fn protocol(&self) -> Option<&str> {
461        self.protocol.as_deref()
462    }
463}
464
465// no two websockets are ever equal
466impl<I, O, E> PartialEq for Websocket<I, O, E> {
467    fn eq(&self, _other: &Self) -> bool {
468        false
469    }
470}
471
472// Create a new WebSocket connection that uses the provided function to handle incoming messages
473impl<In, Out, E> IntoResponse for Websocket<In, Out, E> {
474    fn into_response(self) -> Response {
475        let Some(response) = self.response else {
476            return HttpError::new(
477                StatusCode::INTERNAL_SERVER_ERROR,
478                "WebSocket response not initialized",
479            )
480            .into_response();
481        };
482
483        response.into_response()
484    }
485}
486
487impl<I, O, E> FromResponse<UpgradingWebsocket> for Websocket<I, O, E> {
488    fn from_response(res: UpgradingWebsocket) -> impl Future<Output = Result<Self, ServerFnError>> {
489        async move {
490            #[cfg(not(target_arch = "wasm32"))]
491            let native = res.native;
492
493            #[cfg(feature = "web")]
494            let web = res.web.map(|f| {
495                let (sender, receiver) = f.split();
496                WebsysSocket {
497                    sender: Mutex::new(sender),
498                    receiver: Mutex::new(receiver),
499                }
500            });
501
502            Ok(Websocket {
503                protocol: res.protocol,
504                #[cfg(not(target_arch = "wasm32"))]
505                native,
506                #[cfg(feature = "web")]
507                web,
508                response: None,
509                _in: PhantomData,
510            })
511        }
512    }
513}
514
/// Options used to open (client) or accept (server) a WebSocket connection.
///
/// On the server, this is extracted from the incoming request and carries the pending upgrade.
pub struct WebSocketOptions {
    /// Subprotocols offered when opening the connection.
    protocols: Vec<String>,

    /// Flag set by `with_automatic_reconnect`.
    automatic_reconnect: bool,

    /// Pending upgrade extracted from the request (server only).
    #[cfg(feature = "server")]
    upgrade: Option<axum::extract::ws::WebSocketUpgrade>,

    /// Callback invoked if the upgrade fails (server only).
    #[cfg(feature = "server")]
    on_failed_upgrade: Option<Box<dyn FnOnce(axum::Error) + Send + 'static>>,
}
523
524impl WebSocketOptions {
525    pub fn new() -> Self {
526        Self {
527            protocols: Vec::new(),
528            automatic_reconnect: false,
529
530            #[cfg(feature = "server")]
531            upgrade: None,
532
533            #[cfg(feature = "server")]
534            on_failed_upgrade: None,
535        }
536    }
537
538    /// Automatically reconnect if the connection is lost. This uses an exponential backoff strategy.
539    pub fn with_automatic_reconnect(mut self) -> Self {
540        self.automatic_reconnect = true;
541        self
542    }
543
544    #[cfg(feature = "server")]
545    pub fn on_failed_upgrade(
546        mut self,
547        callback: impl FnOnce(axum::Error) + Send + 'static,
548    ) -> Self {
549        self.on_failed_upgrade = Some(Box::new(callback));
550
551        self
552    }
553
554    #[cfg(feature = "server")]
555    pub fn on_upgrade<F, Fut, In, Out, Enc>(mut self, callback: F) -> Websocket<In, Out, Enc>
556    where
557        F: FnOnce(TypedWebsocket<In, Out, Enc>) -> Fut + Send + 'static,
558        Fut: Future<Output = ()> + 'static,
559    {
560        let on_failed_upgrade = self.on_failed_upgrade.take();
561        let response = self
562            .upgrade
563            .unwrap()
564            .on_failed_upgrade(|e| {
565                if let Some(callback) = on_failed_upgrade {
566                    callback(e);
567                }
568            })
569            .on_upgrade(|socket| {
570                let res = crate::spawn_platform(move || {
571                    callback(TypedWebsocket {
572                        _in: PhantomData,
573                        _out: PhantomData,
574                        _enc: PhantomData,
575                        inner: socket,
576                    })
577                });
578                async move {
579                    let _ = res.await;
580                }
581            });
582
583        Websocket {
584            // the protocol is none here since it won't be accessible until after the upgrade
585            protocol: None,
586            response: Some(response),
587            _in: PhantomData,
588
589            #[cfg(not(target_arch = "wasm32"))]
590            native: None,
591
592            #[cfg(feature = "web")]
593            web: None,
594        }
595    }
596}
597
598impl Default for WebSocketOptions {
599    fn default() -> Self {
600        Self::new()
601    }
602}
603
604impl IntoRequest<UpgradingWebsocket> for WebSocketOptions {
605    fn into_request(
606        self,
607        request: ClientRequest,
608    ) -> impl Future<Output = std::result::Result<UpgradingWebsocket, RequestError>> + 'static {
609        async move {
610            #[cfg(feature = "web")]
611            if cfg!(target_arch = "wasm32") {
612                let url_path = request.url().path();
613                let url_query = request.url().query();
614                let url_fragment = request.url().fragment();
615                let path_and_query = format!(
616                    "{}{}{}",
617                    url_path,
618                    url_query.map_or("".to_string(), |q| format!("?{q}")),
619                    url_fragment.map_or("".to_string(), |f| format!("#{f}"))
620                );
621
622                let socket = gloo_net::websocket::futures::WebSocket::open_with_protocols(
623                    // ! very important we use the path here and not the full url on web.
624                    // for as long as serverfns are meant to target the same origin, this is fine.
625                    &path_and_query,
626                    &self
627                        .protocols
628                        .iter()
629                        .map(String::as_str)
630                        .collect::<Vec<_>>(),
631                )
632                .map_err(|error| RequestError::Connect(error.to_string()))?;
633
634                return Ok(UpgradingWebsocket {
635                    protocol: Some(socket.protocol()),
636                    web: Some(socket),
637                    #[cfg(not(target_arch = "wasm32"))]
638                    native: None,
639                });
640            }
641
642            #[cfg(not(target_arch = "wasm32"))]
643            {
644                let response = native::send_request(request, &self.protocols).await?;
645
646                let (inner, protocol) = response
647                    .into_stream_and_protocol(self.protocols, None)
648                    .await?;
649
650                return Ok(UpgradingWebsocket {
651                    protocol,
652                    native: Some(inner),
653                    #[cfg(feature = "web")]
654                    web: None,
655                });
656            }
657
658            unimplemented!("Non web wasm32 clients are not supported yet")
659        }
660    }
661}
662
663impl<S: Send> FromRequest<S> for WebSocketOptions {
664    type Rejection = axum::response::Response;
665
666    fn from_request(
667        _req: Request,
668        _: &S,
669    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
670        #[cfg(not(feature = "server"))]
671        return async move { Err(StatusCode::NOT_IMPLEMENTED.into_response()) };
672
673        #[cfg(feature = "server")]
674        async move {
675            let ws = match axum::extract::ws::WebSocketUpgrade::from_request(_req, &()).await {
676                Ok(ws) => ws,
677                Err(rejection) => return Err(rejection.into_response()),
678            };
679
680            Ok(WebSocketOptions {
681                protocols: vec![],
682                automatic_reconnect: false,
683                upgrade: Some(ws),
684                on_failed_upgrade: None,
685            })
686        }
687    }
688}
689
690#[doc(hidden)]
691pub struct UpgradingWebsocket {
692    protocol: Option<String>,
693
694    #[cfg(feature = "web")]
695    web: Option<gloo_net::websocket::futures::WebSocket>,
696
697    #[cfg(not(target_arch = "wasm32"))]
698    native: Option<native::SplitSocket>,
699}
700
// SAFETY: NOTE(review) — no justification is given in this file. `UpgradingWebsocket` may hold
// a `gloo_net` browser socket, which is presumably neither `Send` nor `Sync`. These impls appear
// to rely on wasm32 being single-threaded, so the value never actually crosses threads. The
// native `SplitSocket` case should be verified separately. Confirm before relying on these impls.
unsafe impl Send for UpgradingWebsocket {}
unsafe impl Sync for UpgradingWebsocket {}
703
/// Server-side WebSocket passed to the `WebSocketOptions::on_upgrade` callback.
///
/// Wraps an Axum WebSocket: incoming messages are decoded as `In`, and outgoing `Out` messages
/// are encoded with `E`.
#[cfg(feature = "server")]
pub struct TypedWebsocket<In, Out, E = JsonEncoding> {
    /// The underlying Axum socket.
    inner: axum::extract::ws::WebSocket,

    // Type markers; `fn() -> T` keeps the struct's auto traits independent of the parameters.
    _in: std::marker::PhantomData<fn() -> In>,
    _out: std::marker::PhantomData<fn() -> Out>,
    _enc: std::marker::PhantomData<fn() -> E>,
}
712
#[cfg(feature = "server")]
impl<In: DeserializeOwned, Out: Serialize, E: Encoding> TypedWebsocket<In, Out, E> {
    /// Receive the next message from the client, decoded into `In`.
    ///
    /// Ping and pong messages are skipped. If the stream has ended, a "closed" error is
    /// returned instead of `None`.
    pub async fn recv(&mut self) -> Result<In, WebsocketError> {
        // `unwrap_or_else` so the error value is only built when the stream has actually ended.
        self.next()
            .await
            .unwrap_or_else(|| Err(WebsocketError::closed_away()))
    }

    /// Encode `msg` with `E` and send it to the client as a binary message.
    pub async fn send(&mut self, msg: Out) -> Result<(), WebsocketError> {
        SinkExt::send(self, msg).await
    }

    /// Receive the next raw message, including ping, pong and close messages.
    ///
    /// A close message without a frame is reported with `CloseCode::Away` and reason `"Away"`.
    ///
    /// # Errors
    ///
    /// Returns a "closed" error when the stream has ended, and `WebsocketError::AlreadyClosed`
    /// when the underlying socket reports an error.
    pub async fn recv_raw(&mut self) -> Result<Message, WebsocketError> {
        use axum::extract::ws::Message as AxumMessage;

        let message = self
            .inner
            .next()
            .await
            .ok_or_else(WebsocketError::closed_away)?
            .map_err(|_| WebsocketError::AlreadyClosed)?;

        Ok(match message {
            AxumMessage::Text(utf8_bytes) => Message::Text(utf8_bytes.to_string()),
            AxumMessage::Binary(bytes) => Message::Binary(bytes),
            AxumMessage::Ping(bytes) => Message::Ping(bytes),
            AxumMessage::Pong(bytes) => Message::Pong(bytes),
            // Destructure the frame directly instead of cloning it to read code and reason.
            AxumMessage::Close(Some(frame)) => Message::Close {
                code: frame.code.into(),
                reason: frame.reason.to_string(),
            },
            AxumMessage::Close(None) => Message::Close {
                code: CloseCode::Away,
                reason: "Away".to_string(),
            },
        })
    }

    /// Send a raw message (text, binary, ping, pong or close) to the client.
    ///
    /// # Errors
    ///
    /// Returns `WebsocketError::AlreadyClosed` if the underlying socket rejects the send.
    pub async fn send_raw(&mut self, msg: Message) -> Result<(), WebsocketError> {
        let real = match msg {
            Message::Text(text) => axum::extract::ws::Message::Text(text.into()),
            Message::Binary(bytes) => axum::extract::ws::Message::Binary(bytes),
            Message::Ping(bytes) => axum::extract::ws::Message::Ping(bytes),
            Message::Pong(bytes) => axum::extract::ws::Message::Pong(bytes),
            Message::Close { code, reason } => {
                axum::extract::ws::Message::Close(Some(axum::extract::ws::CloseFrame {
                    code: code.into(),
                    reason: reason.into(),
                }))
            }
        };

        self.inner
            .send(real)
            .await
            .map_err(|_err| WebsocketError::AlreadyClosed)
    }

    /// Return the selected WebSocket subprotocol, if one has been chosen.
    pub fn protocol(&self) -> Option<&http::HeaderValue> {
        self.inner.protocol()
    }

    /// Get a mutable reference to the underlying Axum WebSocket.
    pub fn socket(&mut self) -> &mut axum::extract::ws::WebSocket {
        &mut self.inner
    }
}
783
#[cfg(feature = "server")]
impl<In: DeserializeOwned, Out: Serialize, E: Encoding> Stream for TypedWebsocket<In, Out, E> {
    type Item = Result<In, WebsocketError>;

    /// Yield the next data message from the client, decoded into `In`.
    ///
    /// Ping and pong messages are skipped. A close message is yielded as an error, and the
    /// stream ends (`None`) once the underlying socket ends.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use axum::extract::ws::Message as AxumMessage;

        loop {
            // Pull the next frame, stopping on end-of-stream or a transport error.
            let frame = match ready!(self.inner.poll_next_unpin(cx)) {
                None => return Poll::Ready(None),
                Some(Err(_)) => return Poll::Ready(Some(Err(WebsocketError::closed_away()))),
                Some(Ok(frame)) => frame,
            };

            // Data frames yield a payload; control frames are handled right here.
            let payload: Bytes = match frame {
                AxumMessage::Text(text) => text.into(),
                AxumMessage::Binary(bytes) => bytes,
                AxumMessage::Ping(_) | AxumMessage::Pong(_) => continue,
                AxumMessage::Close(Some(close_frame)) => {
                    return Poll::Ready(Some(Err(WebsocketError::ConnectionClosed {
                        code: close_frame.code.into(),
                        description: close_frame.reason.to_string(),
                    })));
                }
                AxumMessage::Close(None) => {
                    return Poll::Ready(Some(Err(WebsocketError::AlreadyClosed)));
                }
            };

            let decoded: In = E::decode(payload).ok_or_else(WebsocketError::deserialization)?;
            return Poll::Ready(Some(Ok(decoded)));
        }
    }
}
825
#[cfg(feature = "server")]
impl<In: DeserializeOwned, Out: Serialize, E: Encoding> Sink<Out> for TypedWebsocket<In, Out, E> {
    type Error = WebsocketError;

    /// Forward readiness from the underlying socket; its errors map to `AlreadyClosed`.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.inner)
            .poll_ready(cx)
            .map_err(|_| WebsocketError::AlreadyClosed)
    }

    /// Encode `item` with `E` and queue it as a binary message.
    fn start_send(mut self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
        use axum::extract::ws::Message;

        // Use the shared constructor so serialization failures match `Websocket::send` and
        // `UseWebsocket::send`, instead of building an ad-hoc error here.
        let bytes = E::to_bytes(&item).ok_or_else(WebsocketError::serialization)?;

        Pin::new(&mut self.inner)
            .start_send(Message::Binary(bytes))
            .map_err(|_| WebsocketError::AlreadyClosed)
    }

    /// Flush queued messages; socket errors map to `AlreadyClosed`.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.inner)
            .poll_flush(cx)
            .map_err(|_| WebsocketError::AlreadyClosed)
    }

    /// Close the underlying socket; socket errors map to `AlreadyClosed`.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.inner)
            .poll_close(cx)
            .map_err(|_| WebsocketError::AlreadyClosed)
    }
}
860
861#[derive(thiserror::Error, Debug)]
862pub enum WebsocketError {
863    #[error("Connection closed")]
864    ConnectionClosed {
865        code: CloseCode,
866        description: String,
867    },
868
869    #[error("WebSocket already closed")]
870    AlreadyClosed,
871
872    #[error("WebSocket capacity reached")]
873    Capacity,
874
875    #[error("An unexpected internal error occurred")]
876    Unexpected,
877
878    #[error("WebSocket is not initialized on this platform")]
879    Uninitialized,
880
881    #[cfg(not(target_arch = "wasm32"))]
882    #[error("websocket upgrade failed")]
883    Handshake(#[from] native::HandshakeError),
884
885    #[error("reqwest error")]
886    Reqwest(#[from] reqwest::Error),
887
888    #[cfg(not(target_arch = "wasm32"))]
889    #[error("tungstenite error")]
890    Tungstenite(#[from] tungstenite::Error),
891
892    /// Error during serialization/deserialization.
893    #[error("error during serialization/deserialization")]
894    Deserialization(Box<dyn std::error::Error + Send + Sync>),
895
896    /// Error during serialization/deserialization.
897    #[error("error during serialization/deserialization")]
898    Serialization(Box<dyn std::error::Error + Send + Sync>),
899
900    /// Error during serialization/deserialization.
901    #[error("serde_json error")]
902    Json(#[from] serde_json::Error),
903
904    /// Error during serialization/deserialization.
905    #[error("ciborium error")]
906    Cbor(#[from] ciborium::de::Error<std::io::Error>),
907}
908
#[cfg(feature = "web")]
impl From<gloo_net::websocket::WebSocketError> for WebsocketError {
    /// Maps a browser-side gloo-net websocket error onto [`WebsocketError`].
    ///
    /// A close event keeps its code and reason, a connection error means the socket is gone,
    /// and everything else (including send failures) becomes [`WebsocketError::Unexpected`].
    fn from(err: gloo_net::websocket::WebSocketError) -> Self {
        use gloo_net::websocket::WebSocketError as GlooError;

        match err {
            GlooError::ConnectionClose(event) => Self::ConnectionClosed {
                code: event.code.into(),
                description: event.reason,
            },
            GlooError::ConnectionError => Self::AlreadyClosed,
            _ => Self::Unexpected,
        }
    }
}
924
925impl WebsocketError {
926    pub fn closed_away() -> Self {
927        Self::ConnectionClosed {
928            code: CloseCode::Normal,
929            description: "Connection closed normally".into(),
930        }
931    }
932
933    pub fn deserialization() -> Self {
934        Self::Deserialization(anyhow::anyhow!("Failed to deserialize message").into())
935    }
936
937    pub fn serialization() -> Self {
938        Self::Serialization(anyhow::anyhow!("Failed to serialize message").into())
939    }
940}
941
/// Browser-side (`web` feature) websocket split into a send half and a receive half.
///
/// Each half sits behind its own `Mutex`, so sending and receiving take separate locks.
/// NOTE(review): `Mutex`, `WsWebsocket` and `WsMessage` are imported outside this view —
/// presumably an async mutex and gloo-net's futures websocket types; confirm.
#[cfg(feature = "web")]
struct WebsysSocket {
    /// Outgoing half of the socket.
    sender: Mutex<SplitSink<WsWebsocket, WsMessage>>,
    /// Incoming half of the socket.
    receiver: Mutex<SplitStream<WsWebsocket>>,
}
947
948/// A `WebSocket` message, which can be a text string or binary data.
949#[derive(Clone, Debug)]
950pub enum Message {
951    /// A text `WebSocket` message.
952    // note: we can't use `tungstenite::Utf8String` here, since we don't have tungstenite on wasm.
953    Text(String),
954
955    /// A binary `WebSocket` message.
956    Binary(Bytes),
957
958    /// A ping message with the specified payload.
959    ///
960    /// The payload here must have a length less than 125 bytes.
961    ///
962    /// # WASM
963    ///
964    /// This variant is ignored for WASM targets.
965    Ping(Bytes),
966
967    /// A pong message with the specified payload.
968    ///
969    /// The payload here must have a length less than 125 bytes.
970    ///
971    /// # WASM
972    ///
973    /// This variant is ignored for WASM targets.
974    Pong(Bytes),
975
976    /// A close message.
977    ///
978    /// Sending this will not close the connection, though the remote peer will likely close the connection after receiving this.
979    Close { code: CloseCode, reason: String },
980}
981
982impl From<String> for Message {
983    #[inline]
984    fn from(value: String) -> Self {
985        Self::Text(value)
986    }
987}
988
989impl From<&str> for Message {
990    #[inline]
991    fn from(value: &str) -> Self {
992        Self::from(value.to_owned())
993    }
994}
995
996impl From<Bytes> for Message {
997    #[inline]
998    fn from(value: Bytes) -> Self {
999        Self::Binary(value)
1000    }
1001}
1002
1003impl From<Vec<u8>> for Message {
1004    #[inline]
1005    fn from(value: Vec<u8>) -> Self {
1006        Self::from(Bytes::from(value))
1007    }
1008}
1009
1010impl From<&[u8]> for Message {
1011    #[inline]
1012    fn from(value: &[u8]) -> Self {
1013        Self::from(Bytes::copy_from_slice(value))
1014    }
1015}
1016
/// Status code used to indicate why an endpoint is closing the `WebSocket` connection, as
/// defined by [RFC 6455 §7.4][1].
///
/// The named variants cover the codes defined by RFC 6455 plus the registered codes `1012` and
/// `1013`. The numeric variants (`Reserved`, `Iana`, `Library`, `Bad`) carry the raw code; the
/// ranges documented on them are the ones `From<u16>` produces, but nothing prevents
/// constructing them with other values directly.
///
/// Converting a `u16` into a `CloseCode` and back yields the original number for every `u16`,
/// and [`Display`] prints that number.
///
/// [1]: https://datatracker.ietf.org/doc/html/rfc6455#section-7.4
/// [`Display`]: std::fmt::Display
#[derive(Debug, Default, Eq, PartialEq, Clone, Copy)]
#[non_exhaustive]
pub enum CloseCode {
    /// `1000`: normal closure — the purpose for which the connection was established has been
    /// fulfilled. This is the [`Default`] close code.
    #[default]
    Normal,

    /// `1001`: an endpoint is "going away", such as a server going down or a browser having
    /// navigated away from a page.
    Away,

    /// `1002`: an endpoint is terminating the connection due to a protocol error.
    Protocol,

    /// `1003`: an endpoint is terminating the connection because it has received a type of data
    /// it cannot accept (e.g., an endpoint that understands only text data MAY send this if it
    /// receives a binary message).
    Unsupported,

    /// `1005`: no status code was present in the received close frame.
    ///
    /// Reserved: it must not be sent in a Close frame (see [`CloseCode::is_allowed`]).
    Status,

    /// `1006`: the connection was closed abnormally, e.g. without a Close frame being sent or
    /// received.
    ///
    /// Reserved: it must not be sent in a Close frame (see [`CloseCode::is_allowed`]).
    Abnormal,

    /// `1007`: an endpoint is terminating the connection because it has received data within a
    /// message that was not consistent with the type of the message (e.g., non-UTF-8
    /// \[RFC3629\] data within a text message).
    Invalid,

    /// `1008`: an endpoint is terminating the connection because it has received a message that
    /// violates its policy. This is a generic status code that can be returned when there is no
    /// other more suitable status code (e.g., `Unsupported` or `Size`) or if there is a need to
    /// hide specific details about the policy.
    Policy,

    /// `1009`: an endpoint is terminating the connection because it has received a message that
    /// is too big for it to process.
    Size,

    /// `1010`: a client is terminating the connection because it expected the server to
    /// negotiate one or more extensions, but the server didn't return them in the response to
    /// the `WebSocket` handshake. The list of required extensions should be given as the reason.
    /// Servers do not use this code, because they can fail the handshake instead.
    Extension,

    /// `1011`: a server is terminating the connection because it encountered an unexpected
    /// condition that prevented it from fulfilling the request.
    Error,

    /// `1012`: the server is restarting. A client may choose to reconnect, and if it does, it
    /// should use a randomized delay of 5-30 seconds between attempts.
    Restart,

    /// `1013`: the server is overloaded and the client should either connect to a different IP
    /// (when multiple targets exist), or reconnect to the same IP when a user has performed an
    /// action.
    Again,

    /// `1015`: the connection was closed due to a failure to perform a TLS handshake (e.g., the
    /// server certificate can't be verified).
    ///
    /// Reserved: it must not be sent in a Close frame (see [`CloseCode::is_allowed`]).
    Tls,

    /// `1016..=2999`: reserved by RFC 6455 for future definition. Not allowed to be sent.
    Reserved(u16),

    /// `3000..=3999`: for use by libraries, frameworks, and applications. These status codes are
    /// registered directly with IANA; their interpretation is undefined by the `WebSocket`
    /// protocol.
    Iana(u16),

    /// `4000..=4999`: reserved for private use. These can't be registered and can be used by
    /// prior agreement between `WebSocket` applications; their interpretation is undefined by
    /// the `WebSocket` protocol.
    Library(u16),

    /// Any code with no other variant: `0..=999`, `1004`, `1014`, and `5000` and above. Not
    /// allowed to be sent.
    Bad(u16),
}

impl CloseCode {
    /// Check if this `CloseCode` may be sent in a Close frame.
    ///
    /// Returns `false` for [`Status`](Self::Status) (`1005`), [`Abnormal`](Self::Abnormal)
    /// (`1006`) and [`Tls`](Self::Tls) (`1015`), which endpoints must never send, and for the
    /// [`Reserved`](Self::Reserved) and [`Bad`](Self::Bad) variants. Every other code is allowed.
    #[must_use]
    pub const fn is_allowed(self) -> bool {
        !matches!(
            self,
            Self::Bad(_) | Self::Reserved(_) | Self::Status | Self::Abnormal | Self::Tls
        )
    }
}

impl std::fmt::Display for CloseCode {
    /// Formats the numeric close code, e.g. `1000` for [`CloseCode::Normal`].
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let code: u16 = (*self).into();
        write!(f, "{code}")
    }
}

impl From<CloseCode> for u16 {
    /// Returns the wire value of the close code; numeric variants return their payload as-is.
    fn from(code: CloseCode) -> Self {
        match code {
            CloseCode::Normal => 1000,
            CloseCode::Away => 1001,
            CloseCode::Protocol => 1002,
            CloseCode::Unsupported => 1003,
            CloseCode::Status => 1005,
            CloseCode::Abnormal => 1006,
            CloseCode::Invalid => 1007,
            CloseCode::Policy => 1008,
            CloseCode::Size => 1009,
            CloseCode::Extension => 1010,
            CloseCode::Error => 1011,
            CloseCode::Restart => 1012,
            CloseCode::Again => 1013,
            CloseCode::Tls => 1015,
            CloseCode::Reserved(code)
            | CloseCode::Iana(code)
            | CloseCode::Library(code)
            | CloseCode::Bad(code) => code,
        }
    }
}

impl From<u16> for CloseCode {
    /// Classifies a wire value: known codes map to their named variant, the rest to the
    /// numeric variant for their range (see the variant docs for the ranges).
    fn from(code: u16) -> Self {
        match code {
            1000 => Self::Normal,
            1001 => Self::Away,
            1002 => Self::Protocol,
            1003 => Self::Unsupported,
            1005 => Self::Status,
            1006 => Self::Abnormal,
            1007 => Self::Invalid,
            1008 => Self::Policy,
            1009 => Self::Size,
            1010 => Self::Extension,
            1011 => Self::Error,
            1012 => Self::Restart,
            1013 => Self::Again,
            1015 => Self::Tls,
            1016..=2999 => Self::Reserved(code),
            3000..=3999 => Self::Iana(code),
            4000..=4999 => Self::Library(code),
            _ => Self::Bad(code),
        }
    }
}
1187
1188#[cfg(not(target_arch = "wasm32"))]
1189mod native {
1190    use std::borrow::Cow;
1191
1192    use crate::ClientRequest;
1193
1194    use super::{CloseCode, Message, WebsocketError};
1195    use dioxus_fullstack_core::RequestError;
1196    use reqwest::{
1197        header::{HeaderName, HeaderValue},
1198        Response, StatusCode, Version,
1199    };
1200    use tungstenite::protocol::WebSocketConfig;
1201
1202    pub(crate) struct SplitSocket {
1203        pub sender: futures_util::lock::Mutex<
1204            async_tungstenite::WebSocketSender<tokio_util::compat::Compat<reqwest::Upgraded>>,
1205        >,
1206
1207        pub receiver: futures_util::lock::Mutex<
1208            async_tungstenite::WebSocketReceiver<tokio_util::compat::Compat<reqwest::Upgraded>>,
1209        >,
1210    }
1211
1212    pub async fn send_request(
1213        request: ClientRequest,
1214        protocols: &[String],
1215    ) -> Result<WebSocketResponse, WebsocketError> {
1216        let request_builder = request.new_reqwest_request();
1217        let (client, request_result) = request_builder.build_split();
1218        let mut request = request_result?;
1219
1220        // change the scheme from wss? to https?
1221        let url = request.url_mut();
1222        match url.scheme() {
1223            "ws" => {
1224                url.set_scheme("http")
1225                    .expect("url should accept http scheme");
1226            }
1227            "wss" => {
1228                url.set_scheme("https")
1229                    .expect("url should accept https scheme");
1230            }
1231            _ => {}
1232        }
1233
1234        // prepare request
1235        let version = request.version();
1236        let nonce = match version {
1237            Version::HTTP_10 | Version::HTTP_11 => {
1238                // HTTP 1 requires us to set some headers.
1239                let nonce_value = tungstenite::handshake::client::generate_key();
1240                let headers = request.headers_mut();
1241                headers.insert(
1242                    reqwest::header::CONNECTION,
1243                    HeaderValue::from_static("upgrade"),
1244                );
1245                headers.insert(
1246                    reqwest::header::UPGRADE,
1247                    HeaderValue::from_static("websocket"),
1248                );
1249                headers.insert(
1250                    reqwest::header::SEC_WEBSOCKET_KEY,
1251                    HeaderValue::from_str(&nonce_value).expect("nonce is a invalid header value"),
1252                );
1253                headers.insert(
1254                    reqwest::header::SEC_WEBSOCKET_VERSION,
1255                    HeaderValue::from_static("13"),
1256                );
1257                if !protocols.is_empty() {
1258                    headers.insert(
1259                        reqwest::header::SEC_WEBSOCKET_PROTOCOL,
1260                        HeaderValue::from_str(&protocols.join(", "))
1261                            .expect("protocols is an invalid header value"),
1262                    );
1263                }
1264
1265                Some(nonce_value)
1266            }
1267            Version::HTTP_2 => {
1268                // TODO: Implement websocket upgrade for HTTP 2.
1269                return Err(HandshakeError::UnsupportedHttpVersion(version).into());
1270            }
1271            _ => {
1272                return Err(HandshakeError::UnsupportedHttpVersion(version).into());
1273            }
1274        };
1275
1276        // execute request
1277        let response = client.execute(request).await?;
1278
1279        Ok(WebSocketResponse {
1280            response,
1281            version,
1282            nonce,
1283        })
1284    }
1285
1286    pub type WebSocketStream =
1287        async_tungstenite::WebSocketStream<tokio_util::compat::Compat<reqwest::Upgraded>>;
1288
1289    /// Error during `Websocket` handshake.
1290    #[derive(Debug, thiserror::Error, Clone)]
1291    pub enum HandshakeError {
1292        #[error("unsupported http version: {0:?}")]
1293        UnsupportedHttpVersion(Version),
1294
1295        #[error(
1296            "the server responded with a different http version. this could be the case because reqwest silently upgraded the connection to http2. see: https://github.com/jgraef/reqwest-websocket/issues/2"
1297        )]
1298        ServerRespondedWithDifferentVersion,
1299
1300        #[error("missing header {header}")]
1301        MissingHeader { header: HeaderName },
1302
1303        #[error("unexpected value for header {header}: expected {expected}, but got {got:?}.")]
1304        UnexpectedHeaderValue {
1305            header: HeaderName,
1306            got: HeaderValue,
1307            expected: Cow<'static, str>,
1308        },
1309
1310        #[error("expected the server to select a protocol.")]
1311        ExpectedAProtocol,
1312
1313        #[error("unexpected protocol: {got}")]
1314        UnexpectedProtocol { got: String },
1315
1316        #[error("unexpected status code: {0}")]
1317        UnexpectedStatusCode(StatusCode),
1318    }
1319
1320    pub struct WebSocketResponse {
1321        pub response: Response,
1322        pub version: Version,
1323        pub nonce: Option<String>,
1324    }
1325
1326    impl WebSocketResponse {
1327        pub async fn into_stream_and_protocol(
1328            self,
1329            protocols: Vec<String>,
1330            web_socket_config: Option<WebSocketConfig>,
1331        ) -> Result<(SplitSocket, Option<String>), WebsocketError> {
1332            let headers = self.response.headers();
1333
1334            if self.response.version() != self.version {
1335                return Err(HandshakeError::ServerRespondedWithDifferentVersion.into());
1336            }
1337
1338            if self.response.status() != reqwest::StatusCode::SWITCHING_PROTOCOLS {
1339                tracing::debug!(status_code = %self.response.status(), "server responded with unexpected status code");
1340                return Err(HandshakeError::UnexpectedStatusCode(self.response.status()).into());
1341            }
1342
1343            if let Some(header) = headers.get(reqwest::header::CONNECTION) {
1344                if !header
1345                    .to_str()
1346                    .is_ok_and(|s| s.eq_ignore_ascii_case("upgrade"))
1347                {
1348                    tracing::debug!("server responded with invalid Connection header: {header:?}");
1349                    return Err(HandshakeError::UnexpectedHeaderValue {
1350                        header: reqwest::header::CONNECTION,
1351                        got: header.clone(),
1352                        expected: "upgrade".into(),
1353                    }
1354                    .into());
1355                }
1356            } else {
1357                tracing::debug!("missing Connection header");
1358                return Err(HandshakeError::MissingHeader {
1359                    header: reqwest::header::CONNECTION,
1360                }
1361                .into());
1362            }
1363
1364            if let Some(header) = headers.get(reqwest::header::UPGRADE) {
1365                if !header
1366                    .to_str()
1367                    .is_ok_and(|s| s.eq_ignore_ascii_case("websocket"))
1368                {
1369                    tracing::debug!("server responded with invalid Upgrade header: {header:?}");
1370                    return Err(HandshakeError::UnexpectedHeaderValue {
1371                        header: reqwest::header::UPGRADE,
1372                        got: header.clone(),
1373                        expected: "websocket".into(),
1374                    }
1375                    .into());
1376                }
1377            } else {
1378                tracing::debug!("missing Upgrade header");
1379                return Err(HandshakeError::MissingHeader {
1380                    header: reqwest::header::UPGRADE,
1381                }
1382                .into());
1383            }
1384
1385            if let Some(nonce) = &self.nonce {
1386                let expected_nonce = tungstenite::handshake::derive_accept_key(nonce.as_bytes());
1387
1388                if let Some(header) = headers.get(reqwest::header::SEC_WEBSOCKET_ACCEPT) {
1389                    if !header.to_str().is_ok_and(|s| s == expected_nonce) {
1390                        tracing::debug!(
1391                            "server responded with invalid Sec-Websocket-Accept header: {header:?}"
1392                        );
1393                        return Err(HandshakeError::UnexpectedHeaderValue {
1394                            header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
1395                            got: header.clone(),
1396                            expected: expected_nonce.into(),
1397                        }
1398                        .into());
1399                    }
1400                } else {
1401                    tracing::debug!("missing Sec-Websocket-Accept header");
1402                    return Err(HandshakeError::MissingHeader {
1403                        header: reqwest::header::SEC_WEBSOCKET_ACCEPT,
1404                    }
1405                    .into());
1406                }
1407            }
1408
1409            let protocol = headers
1410                .get(reqwest::header::SEC_WEBSOCKET_PROTOCOL)
1411                .and_then(|v| v.to_str().ok())
1412                .map(ToOwned::to_owned);
1413
1414            match (protocols.is_empty(), &protocol) {
1415                (true, None) => {
1416                    // we didn't request any protocols, so we don't expect one
1417                    // in return
1418                }
1419                (false, None) => {
1420                    // server didn't reply with a protocol
1421                    return Err(HandshakeError::ExpectedAProtocol.into());
1422                }
1423                (false, Some(protocol)) => {
1424                    if !protocols.contains(protocol) {
1425                        // the responded protocol is none which we requested
1426                        return Err(HandshakeError::UnexpectedProtocol {
1427                            got: protocol.clone(),
1428                        }
1429                        .into());
1430                    }
1431                }
1432                (true, Some(protocol)) => {
1433                    // we didn't request any protocols but got one anyway
1434                    return Err(HandshakeError::UnexpectedProtocol {
1435                        got: protocol.clone(),
1436                    }
1437                    .into());
1438                }
1439            }
1440
1441            use tokio_util::compat::TokioAsyncReadCompatExt;
1442
1443            let inner = WebSocketStream::from_raw_socket(
1444                self.response.upgrade().await?.compat(),
1445                tungstenite::protocol::Role::Client,
1446                web_socket_config,
1447            )
1448            .await;
1449
1450            let split: (
1451                async_tungstenite::WebSocketSender<tokio_util::compat::Compat<reqwest::Upgraded>>,
1452                async_tungstenite::WebSocketReceiver<tokio_util::compat::Compat<reqwest::Upgraded>>,
1453            ) = inner.split();
1454
1455            let split_socket = SplitSocket {
1456                sender: futures_util::lock::Mutex::new(split.0),
1457                receiver: futures_util::lock::Mutex::new(split.1),
1458            };
1459
1460            Ok((split_socket, protocol))
1461        }
1462    }
1463
1464    #[derive(Debug, thiserror::Error)]
1465    #[error("could not convert message")]
1466    pub struct FromTungsteniteMessageError {
1467        pub original: tungstenite::Message,
1468    }
1469
1470    impl TryFrom<tungstenite::Message> for Message {
1471        type Error = FromTungsteniteMessageError;
1472
1473        fn try_from(value: tungstenite::Message) -> Result<Self, Self::Error> {
1474            match value {
1475                tungstenite::Message::Text(text) => Ok(Self::Text(text.as_str().to_owned())),
1476                tungstenite::Message::Binary(data) => Ok(Self::Binary(data)),
1477                tungstenite::Message::Ping(data) => Ok(Self::Ping(data)),
1478                tungstenite::Message::Pong(data) => Ok(Self::Pong(data)),
1479                tungstenite::Message::Close(Some(tungstenite::protocol::CloseFrame {
1480                    code,
1481                    reason,
1482                })) => Ok(Self::Close {
1483                    code: code.into(),
1484                    reason: reason.as_str().to_owned(),
1485                }),
1486                tungstenite::Message::Close(None) => Ok(Self::Close {
1487                    code: CloseCode::default(),
1488                    reason: "".to_owned(),
1489                }),
1490                tungstenite::Message::Frame(_) => {
1491                    Err(FromTungsteniteMessageError { original: value })
1492                }
1493            }
1494        }
1495    }
1496
1497    impl From<Message> for tungstenite::Message {
1498        fn from(value: Message) -> Self {
1499            match value {
1500                Message::Text(text) => Self::Text(tungstenite::Utf8Bytes::from(text)),
1501                Message::Binary(data) => Self::Binary(data),
1502                Message::Ping(data) => Self::Ping(data),
1503                Message::Pong(data) => Self::Pong(data),
1504                Message::Close { code, reason } => {
1505                    Self::Close(Some(tungstenite::protocol::CloseFrame {
1506                        code: code.into(),
1507                        reason: reason.into(),
1508                    }))
1509                }
1510            }
1511        }
1512    }
1513
1514    impl From<tungstenite::protocol::frame::coding::CloseCode> for CloseCode {
1515        fn from(value: tungstenite::protocol::frame::coding::CloseCode) -> Self {
1516            u16::from(value).into()
1517        }
1518    }
1519
1520    impl From<CloseCode> for tungstenite::protocol::frame::coding::CloseCode {
1521        fn from(value: CloseCode) -> Self {
1522            u16::from(value).into()
1523        }
1524    }
1525
1526    impl From<HandshakeError> for RequestError {
1527        fn from(value: HandshakeError) -> Self {
1528            let string = value.to_string();
1529            match value {
1530                HandshakeError::UnexpectedStatusCode(status) => {
1531                    Self::Status(string, status.as_u16())
1532                }
1533                HandshakeError::UnsupportedHttpVersion(_)
1534                | HandshakeError::MissingHeader { .. }
1535                | HandshakeError::UnexpectedHeaderValue { .. }
1536                | HandshakeError::ExpectedAProtocol
1537                | HandshakeError::UnexpectedProtocol { .. }
1538                | HandshakeError::ServerRespondedWithDifferentVersion => Self::Connect(string),
1539            }
1540        }
1541    }
1542
1543    trait IntoRequestError {
1544        fn into_request_error(self) -> RequestError;
1545    }
1546
1547    impl IntoRequestError for reqwest::Error {
1548        fn into_request_error(self) -> RequestError {
1549            const DEFAULT_STATUS_CODE: u16 = 0;
1550            let string = self.to_string();
1551            if self.is_builder() {
1552                RequestError::Builder(string)
1553            } else if self.is_redirect() {
1554                RequestError::Redirect(string)
1555            } else if self.is_status() {
1556                RequestError::Status(
1557                    string,
1558                    self.status()
1559                        .as_ref()
1560                        .map(StatusCode::as_u16)
1561                        .unwrap_or(DEFAULT_STATUS_CODE),
1562                )
1563            } else if self.is_body() {
1564                RequestError::Body(string)
1565            } else if self.is_decode() {
1566                RequestError::Decode(string)
1567            } else if self.is_upgrade() {
1568                RequestError::Connect(string)
1569            } else {
1570                RequestError::Request(string)
1571            }
1572        }
1573    }
1574
1575    impl IntoRequestError for tungstenite::Error {
1576        fn into_request_error(self) -> RequestError {
1577            match self {
1578                tungstenite::Error::ConnectionClosed => {
1579                    RequestError::Connect("websocket connection closed".to_owned())
1580                }
1581                tungstenite::Error::AlreadyClosed => {
1582                    RequestError::Connect("websocket already closed".to_owned())
1583                }
1584                tungstenite::Error::Io(error) => RequestError::Connect(error.to_string()),
1585                tungstenite::Error::Tls(error) => RequestError::Connect(error.to_string()),
1586                tungstenite::Error::Capacity(error) => RequestError::Body(error.to_string()),
1587                tungstenite::Error::Protocol(error) => RequestError::Request(error.to_string()),
1588                tungstenite::Error::WriteBufferFull(message) => {
1589                    RequestError::Body(message.to_string())
1590                }
1591                tungstenite::Error::Utf8(error) => RequestError::Decode(error),
1592                tungstenite::Error::AttackAttempt => {
1593                    RequestError::Request("Tungstenite attack attempt detected".to_owned())
1594                }
1595                tungstenite::Error::Url(error) => RequestError::Builder(error.to_string()),
1596                tungstenite::Error::Http(response) => {
1597                    let status_code = response.status();
1598                    RequestError::Status(format!("HTTP error: {status_code}"), status_code.as_u16())
1599                }
1600                tungstenite::Error::HttpFormat(error) => RequestError::Builder(error.to_string()),
1601            }
1602        }
1603    }
1604
1605    impl From<WebsocketError> for RequestError {
1606        fn from(value: WebsocketError) -> Self {
1607            match value {
1608                WebsocketError::ConnectionClosed { code, description } => {
1609                    Self::Connect(format!("connection closed ({code}): {description}"))
1610                }
1611                WebsocketError::AlreadyClosed => Self::Connect(value.to_string()),
1612                WebsocketError::Capacity => Self::Body(value.to_string()),
1613                WebsocketError::Unexpected => Self::Request(value.to_string()),
1614                WebsocketError::Uninitialized => Self::Builder(value.to_string()),
1615                WebsocketError::Handshake(error) => error.into(),
1616                WebsocketError::Reqwest(error) => error.into_request_error(),
1617                WebsocketError::Tungstenite(error) => error.into_request_error(),
1618                WebsocketError::Serialization(error) => Self::Serialization(error.to_string()),
1619                WebsocketError::Deserialization(error) => Self::Decode(error.to_string()),
1620                WebsocketError::Json(error) => Self::Decode(error.to_string()),
1621                WebsocketError::Cbor(error) => Self::Decode(error.to_string()),
1622            }
1623        }
1624    }
1625}