Skip to main content

barter_integration/protocol/websocket/
mod.rs

1use crate::{Message, error::SocketError, protocol::StreamParser};
2use bytes::Bytes;
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use tokio::net::TcpStream;
6use tokio_tungstenite::{
7    MaybeTlsStream, connect_async,
8    tungstenite::{
9        Utf8Bytes,
10        client::IntoClientRequest,
11        error::ProtocolError,
12        protocol::{CloseFrame, frame::Frame},
13    },
14};
15use tracing::debug;
16
17/// Convenient type alias for a tungstenite `WebSocketStream`.
18pub type WebSocket = tokio_tungstenite::WebSocketStream<MaybeTlsStream<TcpStream>>;
19
20/// Convenient type alias for the `Sink` half of a tungstenite [`WebSocket`].
21pub type WsSink = futures::stream::SplitSink<WebSocket, WsMessage>;
22
23/// Convenient type alias for the `Stream` half of a tungstenite [`WebSocket`].
24pub type WsStream = futures::stream::SplitStream<WebSocket>;
25
26/// Communicative type alias for a tungstenite [`WebSocket`] `Message`.
27pub type WsMessage = tokio_tungstenite::tungstenite::Message;
28
29/// Communicative type alias for a tungstenite [`WebSocket`] `Error`.
30pub type WsError = tokio_tungstenite::tungstenite::Error;
31
32/// [`WebSocket`] administration message variants.
33#[derive(Debug)]
34pub enum AdminWs {
35    Ping(Bytes),
36    Pong(Bytes),
37    Close(Option<CloseFrame>),
38    WsError(WsError),
39}
40
41/// [`WebSocket`] parser.
42///
43/// Translates a `Result<WsMessage, WsError>` to a `Message<AdminWs, Bytes>`.
44#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Deserialize, Serialize)]
45pub struct WsParser;
46
47impl WsParser {
48    pub fn parse(ws_result: Result<WsMessage, WsError>) -> Message<AdminWs, Bytes> {
49        match ws_result {
50            Ok(WsMessage::Text(utf8)) => Message::Payload(Bytes::from(utf8)),
51            Ok(WsMessage::Binary(bytes)) => Message::Payload(bytes),
52            Ok(WsMessage::Frame(frame)) => Message::Payload(frame.into_payload()),
53            Ok(WsMessage::Ping(bytes)) => Message::Admin(AdminWs::Ping(bytes)),
54            Ok(WsMessage::Pong(bytes)) => Message::Admin(AdminWs::Pong(bytes)),
55            Ok(WsMessage::Close(close)) => Message::Admin(AdminWs::Close(close)),
56            Err(error) => Message::Admin(AdminWs::WsError(error)),
57        }
58    }
59}
60
61/// Default [`StreamParser`] implementation for a [`WebSocket`].
62#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Deserialize, Serialize)]
63pub struct WebSocketSerdeParser;
64
65impl<Output> StreamParser<Output> for WebSocketSerdeParser
66where
67    Output: for<'de> Deserialize<'de>,
68{
69    type Stream = WebSocket;
70    type Message = WsMessage;
71    type Error = WsError;
72
73    fn parse(input: Result<Self::Message, Self::Error>) -> Option<Result<Output, SocketError>> {
74        match input {
75            Ok(ws_message) => match ws_message {
76                WsMessage::Text(text) => process_text(text),
77                WsMessage::Binary(binary) => process_binary(binary),
78                WsMessage::Ping(ping) => process_ping(ping),
79                WsMessage::Pong(pong) => process_pong(pong),
80                WsMessage::Close(close_frame) => process_close_frame(close_frame),
81                WsMessage::Frame(frame) => process_frame(frame),
82            },
83            Err(ws_err) => Some(Err(SocketError::WebSocket(Box::new(ws_err)))),
84        }
85    }
86}
87
88/// [`StreamParser`] implementation for a [`WebSocket`] that decodes protobuf
89/// binary payloads.
90#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Deserialize, Serialize)]
91pub struct WebSocketProtobufParser;
92
93impl<Output> StreamParser<Output> for WebSocketProtobufParser
94where
95    Output: prost::Message + Default,
96{
97    type Stream = WebSocket;
98    type Message = WsMessage;
99    type Error = WsError;
100
101    fn parse(input: Result<Self::Message, Self::Error>) -> Option<Result<Output, SocketError>> {
102        match input {
103            Ok(ws_message) => match ws_message {
104                WsMessage::Text(payload) => {
105                    debug!(?payload, "received Text WebSocket message");
106                    None
107                }
108                WsMessage::Binary(binary) => {
109                    Some(Output::decode(binary.as_ref()).map_err(|error| {
110                        SocketError::DeserialiseProtobuf {
111                            error,
112                            payload: binary.to_vec(),
113                        }
114                    }))
115                }
116                WsMessage::Ping(ping) => process_ping::<Output>(ping),
117                WsMessage::Pong(pong) => process_pong::<Output>(pong),
118                WsMessage::Close(close_frame) => process_close_frame::<Output>(close_frame),
119                WsMessage::Frame(frame) => process_frame::<Output>(frame),
120            },
121            Err(ws_err) => Some(Err(SocketError::WebSocket(Box::new(ws_err)))),
122        }
123    }
124}
125
126/// Process a payload of `String` by deserialising into an `ExchangeMessage`.
127pub fn process_text<ExchangeMessage>(
128    payload: Utf8Bytes,
129) -> Option<Result<ExchangeMessage, SocketError>>
130where
131    ExchangeMessage: for<'de> Deserialize<'de>,
132{
133    Some(
134        serde_json::from_str::<ExchangeMessage>(&payload).map_err(|error| {
135            debug!(
136                ?error,
137                ?payload,
138                action = "returning Some(Err(err))",
139                "failed to deserialize WebSocket Message into domain specific Message"
140            );
141            SocketError::Deserialise {
142                error,
143                payload: payload.to_string(),
144            }
145        }),
146    )
147}
148
149/// Process a payload of `Vec<u8>` bytes by deserialising into an `ExchangeMessage`.
150pub fn process_binary<ExchangeMessage>(
151    payload: Bytes,
152) -> Option<Result<ExchangeMessage, SocketError>>
153where
154    ExchangeMessage: for<'de> Deserialize<'de>,
155{
156    Some(
157        serde_json::from_slice::<ExchangeMessage>(&payload).map_err(|error| {
158            debug!(
159                ?error,
160                ?payload,
161                action = "returning Some(Err(err))",
162                "failed to deserialize WebSocket Message into domain specific Message"
163            );
164            SocketError::Deserialise {
165                error,
166                payload: String::from_utf8(payload.into()).unwrap_or_else(|x| x.to_string()),
167            }
168        }),
169    )
170}
171
172/// Basic process for a [`WebSocket`] ping message. Logs the payload at `trace` level.
173pub fn process_ping<ExchangeMessage>(ping: Bytes) -> Option<Result<ExchangeMessage, SocketError>> {
174    debug!(payload = ?ping, "received Ping WebSocket message");
175    None
176}
177
178/// Basic process for a [`WebSocket`] pong message. Logs the payload at `trace` level.
179pub fn process_pong<ExchangeMessage>(pong: Bytes) -> Option<Result<ExchangeMessage, SocketError>> {
180    debug!(payload = ?pong, "received Pong WebSocket message");
181    None
182}
183
184/// Basic process for a [`WebSocket`] CloseFrame message. Logs the payload at `trace` level.
185pub fn process_close_frame<ExchangeMessage>(
186    close_frame: Option<CloseFrame>,
187) -> Option<Result<ExchangeMessage, SocketError>> {
188    let close_frame = format!("{close_frame:?}");
189    debug!(payload = %close_frame, "received CloseFrame WebSocket message");
190    Some(Err(SocketError::Terminated(close_frame)))
191}
192
193/// Basic process for a [`WebSocket`] Frame message. Logs the payload at `trace` level.
194pub fn process_frame<ExchangeMessage>(
195    frame: Frame,
196) -> Option<Result<ExchangeMessage, SocketError>> {
197    let frame = format!("{frame:?}");
198    debug!(payload = %frame, "received unexpected Frame WebSocket message");
199    None
200}
201
202/// Connect asynchronously to a [`WebSocket`] server.
203pub async fn connect<R>(request: R) -> Result<WebSocket, SocketError>
204where
205    R: IntoClientRequest + Unpin + Debug,
206{
207    debug!(?request, "attempting to establish WebSocket connection");
208    connect_async(request)
209        .await
210        .map(|(websocket, _)| websocket)
211        .map_err(|error| SocketError::WebSocket(Box::new(error)))
212}
213
214/// Determine whether a [`WsError`] indicates the [`WebSocket`] has disconnected.
215pub fn is_websocket_disconnected(error: &WsError) -> bool {
216    matches!(
217        error,
218        WsError::ConnectionClosed
219            | WsError::AlreadyClosed
220            | WsError::Io(_)
221            | WsError::Protocol(ProtocolError::SendAfterClosing)
222    )
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;

    /// Run [`WsParser::parse`] on a successfully received message.
    fn parse_ok(ws_message: WsMessage) -> Message<AdminWs, Bytes> {
        WsParser::parse(Ok(ws_message))
    }

    #[test]
    fn test_ws_parser_text_message() {
        let Message::Payload(payload) = parse_ok(WsMessage::Text("hello".into())) else {
            panic!("expected Message::Payload for a Text message");
        };
        assert_eq!(payload, Bytes::from("hello"));
    }

    #[test]
    fn test_ws_parser_binary_message() {
        let expected = Bytes::from_static(b"\x01\x02");
        let Message::Payload(payload) = parse_ok(WsMessage::Binary(expected.clone())) else {
            panic!("expected Message::Payload for a Binary message");
        };
        assert_eq!(payload, expected);
    }

    #[test]
    fn test_ws_parser_ping() {
        let sent = Bytes::from_static(b"ping");
        let Message::Admin(AdminWs::Ping(payload)) = parse_ok(WsMessage::Ping(sent.clone())) else {
            panic!("expected AdminWs::Ping");
        };
        assert_eq!(payload, sent);
    }

    #[test]
    fn test_ws_parser_pong() {
        let sent = Bytes::from_static(b"pong");
        let Message::Admin(AdminWs::Pong(payload)) = parse_ok(WsMessage::Pong(sent.clone())) else {
            panic!("expected AdminWs::Pong");
        };
        assert_eq!(payload, sent);
    }

    #[test]
    fn test_ws_parser_close() {
        let frame = CloseFrame {
            code: CloseCode::Normal,
            reason: "bye".into(),
        };
        let Message::Admin(AdminWs::Close(close)) = parse_ok(WsMessage::Close(Some(frame))) else {
            panic!("expected AdminWs::Close");
        };
        assert!(close.is_some());
    }

    #[test]
    fn test_ws_parser_error() {
        let parsed = WsParser::parse(Err(WsError::ConnectionClosed));
        assert!(
            matches!(parsed, Message::Admin(AdminWs::WsError(_))),
            "a WsError must be surfaced as AdminWs::WsError"
        );
    }
}