use crate::{error::SocketError, protocol::StreamParser};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::fmt::Debug;
use tokio::net::TcpStream;
use tokio_tungstenite::{
connect_async,
tungstenite::{
client::IntoClientRequest,
error::ProtocolError,
protocol::{frame::Frame, CloseFrame},
},
MaybeTlsStream,
};
use tracing::debug;
/// A `tokio_tungstenite` WebSocket connection running over a TCP stream that may or may not be
/// wrapped in TLS ([`MaybeTlsStream`]).
pub type WebSocket = tokio_tungstenite::WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of a split [`WebSocket`]; accepts [`WsMessage`]s.
pub type WsSink = futures::stream::SplitSink<WebSocket, WsMessage>;
/// Read half of a split [`WebSocket`].
pub type WsStream = futures::stream::SplitStream<WebSocket>;
/// Alias for the `tungstenite` message type sent and received over a [`WebSocket`].
pub type WsMessage = tokio_tungstenite::tungstenite::Message;
/// Alias for the `tungstenite` error type produced by WebSocket operations.
pub type WsError = tokio_tungstenite::tungstenite::Error;
/// Unit-struct [`StreamParser`] that converts [`WsMessage`]s read from a [`WebSocket`] into
/// deserialised domain messages by delegating to the `process_*` handlers in this module.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Deserialize, Serialize)]
pub struct WebSocketParser;

impl StreamParser for WebSocketParser {
    type Stream = WebSocket;
    type Message = WsMessage;
    type Error = WsError;

    /// Routes every [`WsMessage`] variant to its dedicated handler.
    ///
    /// Transport errors are wrapped as `SocketError::WebSocket`. Whatever the selected handler
    /// returns (including `None`) is passed straight back to the caller.
    fn parse<Output>(
        input: Result<Self::Message, Self::Error>,
    ) -> Option<Result<Output, SocketError>>
    where
        Output: DeserializeOwned,
    {
        match input {
            Ok(WsMessage::Text(text)) => process_text(text),
            Ok(WsMessage::Binary(bytes)) => process_binary(bytes),
            Ok(WsMessage::Ping(ping_payload)) => process_ping(ping_payload),
            Ok(WsMessage::Pong(pong_payload)) => process_pong(pong_payload),
            Ok(WsMessage::Close(close)) => process_close_frame(close),
            Ok(WsMessage::Frame(raw_frame)) => process_frame(raw_frame),
            Err(transport_error) => Some(Err(SocketError::WebSocket(transport_error))),
        }
    }
}
pub fn process_text<ExchangeMessage>(
payload: String,
) -> Option<Result<ExchangeMessage, SocketError>>
where
ExchangeMessage: DeserializeOwned,
{
Some(
serde_json::from_str::<ExchangeMessage>(&payload).map_err(|error| {
debug!(
?error,
?payload,
action = "returning Some(Err(err))",
"failed to deserialize WebSocket Message into domain specific Message"
);
SocketError::Deserialise { error, payload }
}),
)
}
/// Deserialises a binary WebSocket payload (JSON bytes) into `ExchangeMessage`.
///
/// Always returns `Some`. It holds `Ok` with the decoded message, or
/// `Err(SocketError::Deserialise)` carrying the serde error and the payload rendered as text.
///
/// When the payload is not valid UTF-8, invalid byte sequences are replaced with U+FFFD. The
/// error therefore still shows the payload content rather than a UTF-8 decoding message.
pub fn process_binary<ExchangeMessage>(
    payload: Vec<u8>,
) -> Option<Result<ExchangeMessage, SocketError>>
where
    ExchangeMessage: DeserializeOwned,
{
    Some(
        serde_json::from_slice::<ExchangeMessage>(&payload).map_err(|error| {
            debug!(
                ?error,
                ?payload,
                action = "returning Some(Err(err))",
                "failed to deserialize WebSocket Message into domain specific Message"
            );
            SocketError::Deserialise {
                error,
                // Valid UTF-8 reuses the existing allocation. Otherwise decode the original bytes
                // lossily: `FromUtf8Error`'s Display text describes the decoding failure, not the
                // payload, so it must not be stored here.
                payload: String::from_utf8(payload).unwrap_or_else(|utf8_error| {
                    String::from_utf8_lossy(utf8_error.as_bytes()).into_owned()
                }),
            }
        }),
    )
}
/// Handles a `Ping` WebSocket message by logging its payload at debug level.
///
/// Always returns `None`; no domain message is produced.
pub fn process_ping<ExchangeMessage>(
    ping: Vec<u8>,
) -> Option<Result<ExchangeMessage, SocketError>> {
    let payload = &ping;
    debug!(?payload, "received Ping WebSocket message");
    Option::None
}
/// Handles a `Pong` WebSocket message by logging its payload at debug level.
///
/// Always returns `None`; no domain message is produced.
pub fn process_pong<ExchangeMessage>(
    pong: Vec<u8>,
) -> Option<Result<ExchangeMessage, SocketError>> {
    let payload = &pong;
    debug!(?payload, "received Pong WebSocket message");
    Option::None
}
/// Handles a `Close` WebSocket message.
///
/// The (optional) close frame is rendered with its `Debug` representation. That text is logged
/// at debug level and returned inside `SocketError::Terminated`.
pub fn process_close_frame<ExchangeMessage>(
    close_frame: Option<CloseFrame<'_>>,
) -> Option<Result<ExchangeMessage, SocketError>> {
    let description = format!("{:?}", close_frame);
    debug!(payload = %description, "received CloseFrame WebSocket message");
    let terminated = SocketError::Terminated(description);
    Some(Err(terminated))
}
/// Handles a raw `Frame` WebSocket message, which is not expected here.
///
/// Logs the frame's `Debug` representation at debug level and returns `None`. The `?` sigil
/// formats lazily, so nothing is rendered when debug logging is disabled.
pub fn process_frame<ExchangeMessage>(
    frame: Frame,
) -> Option<Result<ExchangeMessage, SocketError>> {
    debug!(payload = ?frame, "received unexpected Frame WebSocket message");
    None
}
/// Establishes a WebSocket connection for `request` using `connect_async`.
///
/// The HTTP handshake response is discarded and only the stream is returned.
///
/// # Errors
/// Returns `SocketError::WebSocket` wrapping the `tungstenite` error if the connection or
/// handshake fails.
pub async fn connect<R>(request: R) -> Result<WebSocket, SocketError>
where
    R: IntoClientRequest + Unpin + Debug,
{
    debug!(?request, "attempting to establish WebSocket connection");
    let (websocket, _handshake_response) = connect_async(request)
        .await
        .map_err(SocketError::WebSocket)?;
    Ok(websocket)
}
/// Returns `true` if `error` means the WebSocket connection can no longer be used.
///
/// Treated as disconnects:
/// - `ConnectionClosed` and `AlreadyClosed`
/// - any underlying I/O error (`Io`)
/// - `Protocol(SendAfterClosing)`: sending after the close handshake started
/// - `Protocol(ResetWithoutClosingHandshake)`: the peer dropped the connection without
///   completing the closing handshake
pub fn is_websocket_disconnected(error: &WsError) -> bool {
    matches!(
        error,
        WsError::ConnectionClosed
            | WsError::AlreadyClosed
            | WsError::Io(_)
            | WsError::Protocol(ProtocolError::SendAfterClosing)
            | WsError::Protocol(ProtocolError::ResetWithoutClosingHandshake)
    )
}