barter_integration_copy/protocol/
websocket.rs

1use crate::{error::SocketError, protocol::StreamParser};
2use serde::{de::DeserializeOwned, Deserialize, Serialize};
3use std::fmt::Debug;
4use tokio::net::TcpStream;
5use tokio_tungstenite::{
6    connect_async,
7    tungstenite::{
8        client::IntoClientRequest,
9        error::ProtocolError,
10        protocol::{frame::Frame, CloseFrame},
11    },
12    MaybeTlsStream,
13};
14use tracing::debug;
15
16/// Convenient type alias for a tungstenite `WebSocketStream`.
17pub type WebSocket = tokio_tungstenite::WebSocketStream<MaybeTlsStream<TcpStream>>;
18
19/// Convenient type alias for the `Sink` half of a tungstenite [`WebSocket`].
20pub type WsSink = futures::stream::SplitSink<WebSocket, WsMessage>;
21
22/// Convenient type alias for the `Stream` half of a tungstenite [`WebSocket`].
23pub type WsStream = futures::stream::SplitStream<WebSocket>;
24
25/// Communicative type alias for a tungstenite [`WebSocket`] `Message`.
26pub type WsMessage = tokio_tungstenite::tungstenite::Message;
27
28/// Communicative type alias for a tungstenite [`WebSocket`] `Error`.
29pub type WsError = tokio_tungstenite::tungstenite::Error;
30
31/// Default [`StreamParser`] implementation for a [`WebSocket`].
32#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Deserialize, Serialize)]
33pub struct WebSocketParser;
34
35impl StreamParser for WebSocketParser {
36    type Stream = WebSocket;
37    type Message = WsMessage;
38    type Error = WsError;
39
40    fn parse<Output>(
41        input: Result<Self::Message, Self::Error>,
42    ) -> Option<Result<Output, SocketError>>
43    where
44        Output: DeserializeOwned,
45    {
46        match input {
47            Ok(ws_message) => match ws_message {
48                WsMessage::Text(text) => process_text(text),
49                WsMessage::Binary(binary) => process_binary(binary),
50                WsMessage::Ping(ping) => process_ping(ping),
51                WsMessage::Pong(pong) => process_pong(pong),
52                WsMessage::Close(close_frame) => process_close_frame(close_frame),
53                WsMessage::Frame(frame) => process_frame(frame),
54            },
55            Err(ws_err) => Some(Err(SocketError::WebSocket(ws_err))),
56        }
57    }
58}
59
60/// Process a payload of `String` by deserialising into an `ExchangeMessage`.
61pub fn process_text<ExchangeMessage>(
62    payload: String,
63) -> Option<Result<ExchangeMessage, SocketError>>
64where
65    ExchangeMessage: DeserializeOwned,
66{
67    Some(
68        serde_json::from_str::<ExchangeMessage>(&payload).map_err(|error| {
69            debug!(
70                ?error,
71                ?payload,
72                action = "returning Some(Err(err))",
73                "failed to deserialize WebSocket Message into domain specific Message"
74            );
75            SocketError::Deserialise { error, payload }
76        }),
77    )
78}
79
80/// Process a payload of `Vec<u8>` bytes by deserialising into an `ExchangeMessage`.
81pub fn process_binary<ExchangeMessage>(
82    payload: Vec<u8>,
83) -> Option<Result<ExchangeMessage, SocketError>>
84where
85    ExchangeMessage: DeserializeOwned,
86{
87    Some(
88        serde_json::from_slice::<ExchangeMessage>(&payload).map_err(|error| {
89            debug!(
90                ?error,
91                ?payload,
92                action = "returning Some(Err(err))",
93                "failed to deserialize WebSocket Message into domain specific Message"
94            );
95            SocketError::Deserialise {
96                error,
97                payload: String::from_utf8(payload).unwrap_or_else(|x| x.to_string()),
98            }
99        }),
100    )
101}
102
103/// Basic process for a [`WebSocket`] ping message. Logs the payload at `trace` level.
104pub fn process_ping<ExchangeMessage>(
105    ping: Vec<u8>,
106) -> Option<Result<ExchangeMessage, SocketError>> {
107    debug!(payload = ?ping, "received Ping WebSocket message");
108    None
109}
110
111/// Basic process for a [`WebSocket`] pong message. Logs the payload at `trace` level.
112pub fn process_pong<ExchangeMessage>(
113    pong: Vec<u8>,
114) -> Option<Result<ExchangeMessage, SocketError>> {
115    debug!(payload = ?pong, "received Pong WebSocket message");
116    None
117}
118
119/// Basic process for a [`WebSocket`] CloseFrame message. Logs the payload at `trace` level.
120pub fn process_close_frame<ExchangeMessage>(
121    close_frame: Option<CloseFrame<'_>>,
122) -> Option<Result<ExchangeMessage, SocketError>> {
123    let close_frame = format!("{:?}", close_frame);
124    debug!(payload = %close_frame, "received CloseFrame WebSocket message");
125    Some(Err(SocketError::Terminated(close_frame)))
126}
127
128/// Basic process for a [`WebSocket`] Frame message. Logs the payload at `trace` level.
129pub fn process_frame<ExchangeMessage>(
130    frame: Frame,
131) -> Option<Result<ExchangeMessage, SocketError>> {
132    let frame = format!("{:?}", frame);
133    debug!(payload = %frame, "received unexpected Frame WebSocket message");
134    None
135}
136
137/// Connect asynchronously to a [`WebSocket`] server.
138pub async fn connect<R>(request: R) -> Result<WebSocket, SocketError>
139where
140    R: IntoClientRequest + Unpin + Debug,
141{
142    debug!(?request, "attempting to establish WebSocket connection");
143    connect_async(request)
144        .await
145        .map(|(websocket, _)| websocket)
146        .map_err(SocketError::WebSocket)
147}
148
149/// Determine whether a [`WsError`] indicates the [`WebSocket`] has disconnected.
150pub fn is_websocket_disconnected(error: &WsError) -> bool {
151    matches!(
152        error,
153        WsError::ConnectionClosed
154            | WsError::AlreadyClosed
155            | WsError::Io(_)
156            | WsError::Protocol(ProtocolError::SendAfterClosing)
157    )
158}