barter_integration/protocol/websocket.rs

1use crate::{error::SocketError, protocol::StreamParser};
2use bytes::Bytes;
3use serde::{Deserialize, Serialize, de::DeserializeOwned};
4use std::fmt::Debug;
5use tokio::net::TcpStream;
6use tokio_tungstenite::{
7    MaybeTlsStream, connect_async,
8    tungstenite::{
9        Utf8Bytes,
10        client::IntoClientRequest,
11        error::ProtocolError,
12        protocol::{CloseFrame, frame::Frame},
13    },
14};
15use tracing::debug;
16
17/// Convenient type alias for a tungstenite `WebSocketStream`.
18pub type WebSocket = tokio_tungstenite::WebSocketStream<MaybeTlsStream<TcpStream>>;
19
20/// Convenient type alias for the `Sink` half of a tungstenite [`WebSocket`].
21pub type WsSink = futures::stream::SplitSink<WebSocket, WsMessage>;
22
23/// Convenient type alias for the `Stream` half of a tungstenite [`WebSocket`].
24pub type WsStream = futures::stream::SplitStream<WebSocket>;
25
26/// Communicative type alias for a tungstenite [`WebSocket`] `Message`.
27pub type WsMessage = tokio_tungstenite::tungstenite::Message;
28
29/// Communicative type alias for a tungstenite [`WebSocket`] `Error`.
30pub type WsError = tokio_tungstenite::tungstenite::Error;
31
32/// Default [`StreamParser`] implementation for a [`WebSocket`].
33#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Deserialize, Serialize)]
34pub struct WebSocketParser;
35
36impl StreamParser for WebSocketParser {
37    type Stream = WebSocket;
38    type Message = WsMessage;
39    type Error = WsError;
40
41    fn parse<Output>(
42        input: Result<Self::Message, Self::Error>,
43    ) -> Option<Result<Output, SocketError>>
44    where
45        Output: DeserializeOwned,
46    {
47        match input {
48            Ok(ws_message) => match ws_message {
49                WsMessage::Text(text) => process_text(text),
50                WsMessage::Binary(binary) => process_binary(binary),
51                WsMessage::Ping(ping) => process_ping(ping),
52                WsMessage::Pong(pong) => process_pong(pong),
53                WsMessage::Close(close_frame) => process_close_frame(close_frame),
54                WsMessage::Frame(frame) => process_frame(frame),
55            },
56            Err(ws_err) => Some(Err(SocketError::WebSocket(Box::new(ws_err)))),
57        }
58    }
59}
60
61/// Process a payload of `String` by deserialising into an `ExchangeMessage`.
62pub fn process_text<ExchangeMessage>(
63    payload: Utf8Bytes,
64) -> Option<Result<ExchangeMessage, SocketError>>
65where
66    ExchangeMessage: DeserializeOwned,
67{
68    Some(
69        serde_json::from_str::<ExchangeMessage>(&payload).map_err(|error| {
70            debug!(
71                ?error,
72                ?payload,
73                action = "returning Some(Err(err))",
74                "failed to deserialize WebSocket Message into domain specific Message"
75            );
76            SocketError::Deserialise {
77                error,
78                payload: payload.to_string(),
79            }
80        }),
81    )
82}
83
84/// Process a payload of `Vec<u8>` bytes by deserialising into an `ExchangeMessage`.
85pub fn process_binary<ExchangeMessage>(
86    payload: Bytes,
87) -> Option<Result<ExchangeMessage, SocketError>>
88where
89    ExchangeMessage: DeserializeOwned,
90{
91    Some(
92        serde_json::from_slice::<ExchangeMessage>(&payload).map_err(|error| {
93            debug!(
94                ?error,
95                ?payload,
96                action = "returning Some(Err(err))",
97                "failed to deserialize WebSocket Message into domain specific Message"
98            );
99            SocketError::Deserialise {
100                error,
101                payload: String::from_utf8(payload.into()).unwrap_or_else(|x| x.to_string()),
102            }
103        }),
104    )
105}
106
107/// Basic process for a [`WebSocket`] ping message. Logs the payload at `trace` level.
108pub fn process_ping<ExchangeMessage>(ping: Bytes) -> Option<Result<ExchangeMessage, SocketError>> {
109    debug!(payload = ?ping, "received Ping WebSocket message");
110    None
111}
112
113/// Basic process for a [`WebSocket`] pong message. Logs the payload at `trace` level.
114pub fn process_pong<ExchangeMessage>(pong: Bytes) -> Option<Result<ExchangeMessage, SocketError>> {
115    debug!(payload = ?pong, "received Pong WebSocket message");
116    None
117}
118
119/// Basic process for a [`WebSocket`] CloseFrame message. Logs the payload at `trace` level.
120pub fn process_close_frame<ExchangeMessage>(
121    close_frame: Option<CloseFrame>,
122) -> Option<Result<ExchangeMessage, SocketError>> {
123    let close_frame = format!("{close_frame:?}");
124    debug!(payload = %close_frame, "received CloseFrame WebSocket message");
125    Some(Err(SocketError::Terminated(close_frame)))
126}
127
128/// Basic process for a [`WebSocket`] Frame message. Logs the payload at `trace` level.
129pub fn process_frame<ExchangeMessage>(
130    frame: Frame,
131) -> Option<Result<ExchangeMessage, SocketError>> {
132    let frame = format!("{frame:?}");
133    debug!(payload = %frame, "received unexpected Frame WebSocket message");
134    None
135}
136
137/// Connect asynchronously to a [`WebSocket`] server.
138pub async fn connect<R>(request: R) -> Result<WebSocket, SocketError>
139where
140    R: IntoClientRequest + Unpin + Debug,
141{
142    debug!(?request, "attempting to establish WebSocket connection");
143    connect_async(request)
144        .await
145        .map(|(websocket, _)| websocket)
146        .map_err(|error| SocketError::WebSocket(Box::new(error)))
147}
148
149/// Determine whether a [`WsError`] indicates the [`WebSocket`] has disconnected.
150pub fn is_websocket_disconnected(error: &WsError) -> bool {
151    matches!(
152        error,
153        WsError::ConnectionClosed
154            | WsError::AlreadyClosed
155            | WsError::Io(_)
156            | WsError::Protocol(ProtocolError::SendAfterClosing)
157    )
158}