barter_integration/protocol/
websocket.rs1use crate::{error::SocketError, protocol::StreamParser};
2use bytes::Bytes;
3use serde::{Deserialize, Serialize, de::DeserializeOwned};
4use std::fmt::Debug;
5use tokio::net::TcpStream;
6use tokio_tungstenite::{
7 MaybeTlsStream, connect_async,
8 tungstenite::{
9 Utf8Bytes,
10 client::IntoClientRequest,
11 error::ProtocolError,
12 protocol::{CloseFrame, frame::Frame},
13 },
14};
15use tracing::debug;
16
17pub type WebSocket = tokio_tungstenite::WebSocketStream<MaybeTlsStream<TcpStream>>;
19
20pub type WsSink = futures::stream::SplitSink<WebSocket, WsMessage>;
22
23pub type WsStream = futures::stream::SplitStream<WebSocket>;
25
26pub type WsMessage = tokio_tungstenite::tungstenite::Message;
28
29pub type WsError = tokio_tungstenite::tungstenite::Error;
31
32#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Deserialize, Serialize)]
34pub struct WebSocketParser;
35
36impl StreamParser for WebSocketParser {
37 type Stream = WebSocket;
38 type Message = WsMessage;
39 type Error = WsError;
40
41 fn parse<Output>(
42 input: Result<Self::Message, Self::Error>,
43 ) -> Option<Result<Output, SocketError>>
44 where
45 Output: DeserializeOwned,
46 {
47 match input {
48 Ok(ws_message) => match ws_message {
49 WsMessage::Text(text) => process_text(text),
50 WsMessage::Binary(binary) => process_binary(binary),
51 WsMessage::Ping(ping) => process_ping(ping),
52 WsMessage::Pong(pong) => process_pong(pong),
53 WsMessage::Close(close_frame) => process_close_frame(close_frame),
54 WsMessage::Frame(frame) => process_frame(frame),
55 },
56 Err(ws_err) => Some(Err(SocketError::WebSocket(Box::new(ws_err)))),
57 }
58 }
59}
60
61pub fn process_text<ExchangeMessage>(
63 payload: Utf8Bytes,
64) -> Option<Result<ExchangeMessage, SocketError>>
65where
66 ExchangeMessage: DeserializeOwned,
67{
68 Some(
69 serde_json::from_str::<ExchangeMessage>(&payload).map_err(|error| {
70 debug!(
71 ?error,
72 ?payload,
73 action = "returning Some(Err(err))",
74 "failed to deserialize WebSocket Message into domain specific Message"
75 );
76 SocketError::Deserialise {
77 error,
78 payload: payload.to_string(),
79 }
80 }),
81 )
82}
83
84pub fn process_binary<ExchangeMessage>(
86 payload: Bytes,
87) -> Option<Result<ExchangeMessage, SocketError>>
88where
89 ExchangeMessage: DeserializeOwned,
90{
91 Some(
92 serde_json::from_slice::<ExchangeMessage>(&payload).map_err(|error| {
93 debug!(
94 ?error,
95 ?payload,
96 action = "returning Some(Err(err))",
97 "failed to deserialize WebSocket Message into domain specific Message"
98 );
99 SocketError::Deserialise {
100 error,
101 payload: String::from_utf8(payload.into()).unwrap_or_else(|x| x.to_string()),
102 }
103 }),
104 )
105}
106
107pub fn process_ping<ExchangeMessage>(ping: Bytes) -> Option<Result<ExchangeMessage, SocketError>> {
109 debug!(payload = ?ping, "received Ping WebSocket message");
110 None
111}
112
113pub fn process_pong<ExchangeMessage>(pong: Bytes) -> Option<Result<ExchangeMessage, SocketError>> {
115 debug!(payload = ?pong, "received Pong WebSocket message");
116 None
117}
118
119pub fn process_close_frame<ExchangeMessage>(
121 close_frame: Option<CloseFrame>,
122) -> Option<Result<ExchangeMessage, SocketError>> {
123 let close_frame = format!("{close_frame:?}");
124 debug!(payload = %close_frame, "received CloseFrame WebSocket message");
125 Some(Err(SocketError::Terminated(close_frame)))
126}
127
128pub fn process_frame<ExchangeMessage>(
130 frame: Frame,
131) -> Option<Result<ExchangeMessage, SocketError>> {
132 let frame = format!("{frame:?}");
133 debug!(payload = %frame, "received unexpected Frame WebSocket message");
134 None
135}
136
137pub async fn connect<R>(request: R) -> Result<WebSocket, SocketError>
139where
140 R: IntoClientRequest + Unpin + Debug,
141{
142 debug!(?request, "attempting to establish WebSocket connection");
143 connect_async(request)
144 .await
145 .map(|(websocket, _)| websocket)
146 .map_err(|error| SocketError::WebSocket(Box::new(error)))
147}
148
149pub fn is_websocket_disconnected(error: &WsError) -> bool {
151 matches!(
152 error,
153 WsError::ConnectionClosed
154 | WsError::AlreadyClosed
155 | WsError::Io(_)
156 | WsError::Protocol(ProtocolError::SendAfterClosing)
157 )
158}