barter_integration/protocol/websocket/
mod.rs1use crate::{Message, error::SocketError, protocol::StreamParser};
2use bytes::Bytes;
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use tokio::net::TcpStream;
6use tokio_tungstenite::{
7 MaybeTlsStream, connect_async,
8 tungstenite::{
9 Utf8Bytes,
10 client::IntoClientRequest,
11 error::ProtocolError,
12 protocol::{CloseFrame, frame::Frame},
13 },
14};
15use tracing::debug;
16
17pub type WebSocket = tokio_tungstenite::WebSocketStream<MaybeTlsStream<TcpStream>>;
19
20pub type WsSink = futures::stream::SplitSink<WebSocket, WsMessage>;
22
23pub type WsStream = futures::stream::SplitStream<WebSocket>;
25
26pub type WsMessage = tokio_tungstenite::tungstenite::Message;
28
29pub type WsError = tokio_tungstenite::tungstenite::Error;
31
32#[derive(Debug)]
34pub enum AdminWs {
35 Ping(Bytes),
36 Pong(Bytes),
37 Close(Option<CloseFrame>),
38 WsError(WsError),
39}
40
41#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Deserialize, Serialize)]
45pub struct WsParser;
46
47impl WsParser {
48 pub fn parse(ws_result: Result<WsMessage, WsError>) -> Message<AdminWs, Bytes> {
49 match ws_result {
50 Ok(WsMessage::Text(utf8)) => Message::Payload(Bytes::from(utf8)),
51 Ok(WsMessage::Binary(bytes)) => Message::Payload(bytes),
52 Ok(WsMessage::Frame(frame)) => Message::Payload(frame.into_payload()),
53 Ok(WsMessage::Ping(bytes)) => Message::Admin(AdminWs::Ping(bytes)),
54 Ok(WsMessage::Pong(bytes)) => Message::Admin(AdminWs::Pong(bytes)),
55 Ok(WsMessage::Close(close)) => Message::Admin(AdminWs::Close(close)),
56 Err(error) => Message::Admin(AdminWs::WsError(error)),
57 }
58 }
59}
60
61#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Deserialize, Serialize)]
63pub struct WebSocketSerdeParser;
64
65impl<Output> StreamParser<Output> for WebSocketSerdeParser
66where
67 Output: for<'de> Deserialize<'de>,
68{
69 type Stream = WebSocket;
70 type Message = WsMessage;
71 type Error = WsError;
72
73 fn parse(input: Result<Self::Message, Self::Error>) -> Option<Result<Output, SocketError>> {
74 match input {
75 Ok(ws_message) => match ws_message {
76 WsMessage::Text(text) => process_text(text),
77 WsMessage::Binary(binary) => process_binary(binary),
78 WsMessage::Ping(ping) => process_ping(ping),
79 WsMessage::Pong(pong) => process_pong(pong),
80 WsMessage::Close(close_frame) => process_close_frame(close_frame),
81 WsMessage::Frame(frame) => process_frame(frame),
82 },
83 Err(ws_err) => Some(Err(SocketError::WebSocket(Box::new(ws_err)))),
84 }
85 }
86}
87
88#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Deserialize, Serialize)]
91pub struct WebSocketProtobufParser;
92
93impl<Output> StreamParser<Output> for WebSocketProtobufParser
94where
95 Output: prost::Message + Default,
96{
97 type Stream = WebSocket;
98 type Message = WsMessage;
99 type Error = WsError;
100
101 fn parse(input: Result<Self::Message, Self::Error>) -> Option<Result<Output, SocketError>> {
102 match input {
103 Ok(ws_message) => match ws_message {
104 WsMessage::Text(payload) => {
105 debug!(?payload, "received Text WebSocket message");
106 None
107 }
108 WsMessage::Binary(binary) => {
109 Some(Output::decode(binary.as_ref()).map_err(|error| {
110 SocketError::DeserialiseProtobuf {
111 error,
112 payload: binary.to_vec(),
113 }
114 }))
115 }
116 WsMessage::Ping(ping) => process_ping::<Output>(ping),
117 WsMessage::Pong(pong) => process_pong::<Output>(pong),
118 WsMessage::Close(close_frame) => process_close_frame::<Output>(close_frame),
119 WsMessage::Frame(frame) => process_frame::<Output>(frame),
120 },
121 Err(ws_err) => Some(Err(SocketError::WebSocket(Box::new(ws_err)))),
122 }
123 }
124}
125
126pub fn process_text<ExchangeMessage>(
128 payload: Utf8Bytes,
129) -> Option<Result<ExchangeMessage, SocketError>>
130where
131 ExchangeMessage: for<'de> Deserialize<'de>,
132{
133 Some(
134 serde_json::from_str::<ExchangeMessage>(&payload).map_err(|error| {
135 debug!(
136 ?error,
137 ?payload,
138 action = "returning Some(Err(err))",
139 "failed to deserialize WebSocket Message into domain specific Message"
140 );
141 SocketError::Deserialise {
142 error,
143 payload: payload.to_string(),
144 }
145 }),
146 )
147}
148
149pub fn process_binary<ExchangeMessage>(
151 payload: Bytes,
152) -> Option<Result<ExchangeMessage, SocketError>>
153where
154 ExchangeMessage: for<'de> Deserialize<'de>,
155{
156 Some(
157 serde_json::from_slice::<ExchangeMessage>(&payload).map_err(|error| {
158 debug!(
159 ?error,
160 ?payload,
161 action = "returning Some(Err(err))",
162 "failed to deserialize WebSocket Message into domain specific Message"
163 );
164 SocketError::Deserialise {
165 error,
166 payload: String::from_utf8(payload.into()).unwrap_or_else(|x| x.to_string()),
167 }
168 }),
169 )
170}
171
172pub fn process_ping<ExchangeMessage>(ping: Bytes) -> Option<Result<ExchangeMessage, SocketError>> {
174 debug!(payload = ?ping, "received Ping WebSocket message");
175 None
176}
177
178pub fn process_pong<ExchangeMessage>(pong: Bytes) -> Option<Result<ExchangeMessage, SocketError>> {
180 debug!(payload = ?pong, "received Pong WebSocket message");
181 None
182}
183
184pub fn process_close_frame<ExchangeMessage>(
186 close_frame: Option<CloseFrame>,
187) -> Option<Result<ExchangeMessage, SocketError>> {
188 let close_frame = format!("{close_frame:?}");
189 debug!(payload = %close_frame, "received CloseFrame WebSocket message");
190 Some(Err(SocketError::Terminated(close_frame)))
191}
192
193pub fn process_frame<ExchangeMessage>(
195 frame: Frame,
196) -> Option<Result<ExchangeMessage, SocketError>> {
197 let frame = format!("{frame:?}");
198 debug!(payload = %frame, "received unexpected Frame WebSocket message");
199 None
200}
201
202pub async fn connect<R>(request: R) -> Result<WebSocket, SocketError>
204where
205 R: IntoClientRequest + Unpin + Debug,
206{
207 debug!(?request, "attempting to establish WebSocket connection");
208 connect_async(request)
209 .await
210 .map(|(websocket, _)| websocket)
211 .map_err(|error| SocketError::WebSocket(Box::new(error)))
212}
213
214pub fn is_websocket_disconnected(error: &WsError) -> bool {
216 matches!(
217 error,
218 WsError::ConnectionClosed
219 | WsError::AlreadyClosed
220 | WsError::Io(_)
221 | WsError::Protocol(ProtocolError::SendAfterClosing)
222 )
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;

    /// Feeds a successfully-read message through [`WsParser::parse`].
    fn parse_ok(message: WsMessage) -> Message<AdminWs, Bytes> {
        WsParser::parse(Ok(message))
    }

    #[test]
    fn test_ws_parser_text_message() {
        let parsed = parse_ok(WsMessage::Text("hello".into()));
        assert!(matches!(parsed, Message::Payload(payload) if payload == Bytes::from("hello")));
    }

    #[test]
    fn test_ws_parser_binary_message() {
        let expected = Bytes::from_static(b"\x01\x02");
        let parsed = parse_ok(WsMessage::Binary(expected.clone()));
        assert!(matches!(parsed, Message::Payload(payload) if payload == expected));
    }

    #[test]
    fn test_ws_parser_ping() {
        let expected = Bytes::from_static(b"ping");
        let parsed = parse_ok(WsMessage::Ping(expected.clone()));
        assert!(matches!(parsed, Message::Admin(AdminWs::Ping(data)) if data == expected));
    }

    #[test]
    fn test_ws_parser_pong() {
        let expected = Bytes::from_static(b"pong");
        let parsed = parse_ok(WsMessage::Pong(expected.clone()));
        assert!(matches!(parsed, Message::Admin(AdminWs::Pong(data)) if data == expected));
    }

    #[test]
    fn test_ws_parser_close() {
        let frame = CloseFrame {
            code: CloseCode::Normal,
            reason: "bye".into(),
        };
        let parsed = parse_ok(WsMessage::Close(Some(frame)));
        assert!(matches!(parsed, Message::Admin(AdminWs::Close(Some(_)))));
    }

    #[test]
    fn test_ws_parser_error() {
        let parsed = WsParser::parse(Err(WsError::ConnectionClosed));
        assert!(matches!(parsed, Message::Admin(AdminWs::WsError(_))));
    }
}