use super::Message;
use crate::{error::Error, headers::HeaderValue};
use futures_util::{
sink::{Sink, SinkExt},
stream::{SplitSink, SplitStream, Stream, StreamExt},
};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll, ready},
};
use tokio_tungstenite::tungstenite::{
Error as WsError, Message as WsMessage, protocol::CloseFrame,
};
use tokio_tungstenite::{WebSocketStream, tungstenite};
/// A WebSocket connection running over an upgraded hyper connection.
///
/// Wraps a `tokio_tungstenite` [`WebSocketStream`] and keeps the subprotocol header
/// value it was constructed with (readable through `WebSocket::protocol`).
#[derive(Debug)]
pub struct WebSocket {
    // The underlying tungstenite stream that every read and write goes through.
    inner: WebSocketStream<TokioIo<Upgraded>>,
    // Subprotocol header value supplied at construction; presumably the negotiated
    // `Sec-WebSocket-Protocol` value — TODO confirm against the upgrade handler.
    protocol: Option<HeaderValue>,
}
/// Write half of a [`WebSocket`], produced by `WebSocket::split`.
pub struct WsSink(SplitSink<WebSocketStream<TokioIo<Upgraded>>, WsMessage>);
/// Read half of a [`WebSocket`], produced by `WebSocket::split`.
pub struct WsStream(SplitStream<WebSocketStream<TokioIo<Upgraded>>>);
#[derive(Debug)]
pub enum WsEvent<T> {
Data(T),
Close(Option<CloseFrame>),
}
impl std::fmt::Debug for WsSink {
    /// Writes the fixed placeholder `WsSink(..)`; the wrapped sink is not formatted
    /// and formatter flags (width, alternate) are ignored.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "WsSink(..)")
    }
}
impl std::fmt::Debug for WsStream {
    /// Writes the fixed placeholder `WsStream(..)`; the wrapped stream is not
    /// formatted and formatter flags (width, alternate) are ignored.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "WsStream(..)")
    }
}
impl WsSink {
#[inline]
pub fn into_inner(self) -> SplitSink<WebSocketStream<TokioIo<Upgraded>>, WsMessage> {
self.0
}
#[inline]
pub async fn send<T: TryInto<Message, Error = Error>>(&mut self, msg: T) -> Result<(), Error> {
let msg = msg.try_into()?.into();
self.0.send(msg).await.map_err(Error::from)
}
#[inline]
pub async fn close(&mut self) -> Result<(), Error> {
match self.0.close().await {
Ok(()) => Ok(()),
Err(e) if is_expected_close_error(&e) => Ok(()),
Err(e) => Err(Error::from(e)),
}
}
pub async fn send_close(&mut self, frame: Option<CloseFrame>) -> Result<(), Error> {
match self.0.send(WsMessage::Close(frame)).await {
Ok(()) => Ok(()),
Err(e) if is_expected_close_error(&e) => Ok(()),
Err(e) => Err(Error::from(e)),
}
}
}
impl WsStream {
#[inline]
pub fn into_inner(self) -> SplitStream<WebSocketStream<TokioIo<Upgraded>>> {
self.0
}
pub async fn recv<T>(&mut self) -> Option<Result<WsEvent<T>, Error>>
where
T: TryFrom<Message, Error = Error>,
{
loop {
let msg = match self.recv_raw().await? {
Ok(msg) => msg,
Err(err) => return Some(Err(err)),
};
match msg.0 {
WsMessage::Ping(_) | WsMessage::Pong(_) => continue,
WsMessage::Close(frame) => return Some(Ok(WsEvent::Close(frame))),
WsMessage::Text(_) | WsMessage::Binary(_) => {
return Some(T::try_from(msg).map(WsEvent::Data));
}
WsMessage::Frame(_) => {
debug_assert!(
false,
"tungstenite returned a raw Frame while reading messages"
);
continue;
}
}
}
}
#[inline]
async fn recv_raw(&mut self) -> Option<Result<Message, Error>> {
recv_raw_from(&mut self.0).await
}
}
impl WebSocket {
    /// Wraps an upgraded tungstenite stream together with its subprotocol header value.
    #[inline]
    pub(super) fn new(
        inner: WebSocketStream<TokioIo<Upgraded>>,
        protocol: Option<HeaderValue>,
    ) -> Self {
        Self { inner, protocol }
    }

    /// Receives the next Text or Binary message and converts it into `T`.
    ///
    /// Ping and Pong messages are skipped. When a Close message arrives, the sink is
    /// closed via `finish_close` and `None` is returned. A failed close is only logged,
    /// and only when the `tracing` feature is enabled.
    ///
    /// Returns `None` when the stream ends or a Close message was received. Returns
    /// `Some(Err(_))` for a transport error or a failed conversion into `T`.
    pub async fn recv<T>(&mut self) -> Option<Result<T, Error>>
    where
        T: TryFrom<Message, Error = Error>,
    {
        loop {
            let msg = match self.recv_raw().await? {
                Ok(msg) => msg,
                Err(err) => return Some(Err(err)),
            };
            match msg.0 {
                WsMessage::Ping(_) | WsMessage::Pong(_) => continue,
                WsMessage::Text(_) | WsMessage::Binary(_) => return Some(T::try_from(msg)),
                WsMessage::Frame(_) => {
                    // Not expected while reading: flagged in debug builds, skipped otherwise.
                    debug_assert!(
                        false,
                        "tungstenite returned a raw Frame while reading messages"
                    );
                    continue;
                }
                WsMessage::Close(_) => {
                    if let Err(_close_err) = self.finish_close().await {
                        #[cfg(feature = "tracing")]
                        tracing::warn!("WebSocket close failed: {_close_err}");
                    }
                    return None;
                }
            }
        }
    }

    /// Converts `msg` into a [`Message`] and sends it on the connection.
    ///
    /// # Errors
    /// Returns the conversion error unchanged, or the transport error converted with
    /// `Error::from`.
    #[inline]
    pub async fn send<T: TryInto<Message, Error = Error>>(&mut self, msg: T) -> Result<(), Error> {
        let msg = msg.try_into()?;
        self.inner.send(msg.into_inner()).await.map_err(Error::from)
    }

    /// Returns the subprotocol header value this socket was created with, if any.
    pub fn protocol(&self) -> Option<&HeaderValue> {
        self.protocol.as_ref()
    }

    /// Splits the connection into independent write ([`WsSink`]) and read
    /// ([`WsStream`]) halves.
    ///
    /// The subprotocol value is not carried over to the halves; read it with
    /// `protocol()` before splitting if it is needed.
    #[inline]
    pub fn split(self) -> (WsSink, WsStream) {
        let (tx, rx) = self.inner.split();
        (WsSink(tx), WsStream(rx))
    }

    /// Runs a receive → handle → reply loop until the connection ends.
    ///
    /// Each incoming message converted into `M` is passed to `handler`. The value its
    /// future resolves to is converted into a [`Message`] and sent back.
    /// - Receive and conversion errors are logged (with the `tracing` feature) and
    ///   skipped.
    /// - If sending a reply fails, the connection is closed and the loop returns.
    ///
    /// `handler` only needs to be `FnMut`, so it may keep mutable state across
    /// messages; any `Fn` closure also satisfies this bound. It is only called inside
    /// this method, so it does not have to be `'static`.
    #[inline]
    pub async fn on_msg<F, M, R, Fut>(&mut self, mut handler: F)
    where
        F: FnMut(M) -> Fut + Send,
        M: TryFrom<Message, Error = Error>,
        R: TryInto<Message, Error = Error>,
        Fut: Future<Output = R> + Send,
    {
        while let Some(msg) = self.recv::<M>().await {
            let msg = match msg {
                Ok(msg) => msg,
                Err(_e) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!("Error receiving message: {_e}");
                    continue;
                }
            };
            let response = handler(msg).await;
            if let Err(_e) = self.send(response).await {
                #[cfg(feature = "tracing")]
                tracing::error!("Error sending message: {_e}");
                if let Err(_close_err) = self.finish_close().await {
                    #[cfg(feature = "tracing")]
                    tracing::warn!("WebSocket close failed: {_close_err}");
                }
                return;
            }
        }
    }

    /// Closes the connection through the inner stream's `close`, passing `frame` as the
    /// optional close payload.
    ///
    /// Errors meaning the connection is already gone (see `is_expected_close_error`)
    /// are reported as success.
    #[inline]
    pub async fn close(&mut self, frame: Option<CloseFrame>) -> Result<(), Error> {
        match self.inner.close(frame).await {
            Ok(()) => Ok(()),
            Err(e) if is_expected_close_error(&e) => Ok(()),
            Err(e) => Err(Error::from(e)),
        }
    }

    /// Closes the sink side via `SinkExt::close`, without sending a custom frame.
    /// Already-closed style errors are reported as success.
    #[inline]
    async fn finish_close(&mut self) -> Result<(), Error> {
        match SinkExt::close(&mut self.inner).await {
            Ok(()) => Ok(()),
            Err(e) if is_expected_close_error(&e) => Ok(()),
            Err(e) => Err(Error::from(e)),
        }
    }

    /// Pulls the next raw item from the inner stream, wrapped in this module's types.
    #[inline]
    async fn recv_raw(&mut self) -> Option<Result<Message, Error>> {
        recv_raw_from(&mut self.inner).await
    }
}
#[inline]
async fn recv_raw_from<S>(stream: &mut S) -> Option<Result<Message, Error>>
where
S: Stream<Item = Result<WsMessage, tungstenite::Error>> + Unpin,
{
stream
.next()
.await
.map(|r| r.map(Message).map_err(Error::from))
}
#[inline]
fn is_expected_close_error(e: &WsError) -> bool {
match e {
WsError::ConnectionClosed => true,
WsError::AlreadyClosed => true,
WsError::Protocol(p) => matches!(p, tungstenite::error::ProtocolError::SendAfterClosing),
WsError::Io(io) => matches!(
io.kind(),
std::io::ErrorKind::BrokenPipe
| std::io::ErrorKind::ConnectionReset
| std::io::ErrorKind::NotConnected
),
_ => false,
}
}
impl Stream for WebSocket {
    type Item = Result<Message, Error>;

    /// Yields every message read from the connection, including Ping, Pong and Close.
    /// Only raw `Frame` values are skipped. Transport errors are converted into
    /// [`Error`].
    #[inline]
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let Some(result) = ready!(self.inner.poll_next_unpin(cx)) else {
                return Poll::Ready(None);
            };
            let item = match result {
                // Raw frames are not part of the message-level API; poll again.
                Ok(WsMessage::Frame(_)) => continue,
                Ok(msg) => Ok(Message(msg)),
                Err(err) => Err(err.into()),
            };
            return Poll::Ready(Some(item));
        }
    }
}
impl Sink<Message> for WebSocket {
type Error = Error;
#[inline]
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match ready!(Pin::new(&mut self.inner).poll_ready(cx)) {
Ok(()) => Poll::Ready(Ok(())),
Err(e) => Poll::Ready(Err(Error::server_error(e))),
}
}
#[inline]
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match Pin::new(&mut self.inner).start_send(item.0) {
Ok(_) => Ok(()),
Err(err) => Err(Error::server_error(err)),
}
}
#[inline]
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match ready!(Pin::new(&mut self.inner).poll_flush(cx)) {
Ok(_) => Poll::Ready(Ok(())),
Err(err) => Poll::Ready(Err(Error::server_error(err))),
}
}
#[inline]
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match ready!(Pin::new(&mut self.inner).poll_close(cx)) {
Ok(_) => Poll::Ready(Ok(())),
Err(err) => Poll::Ready(Err(Error::server_error(err))),
}
}
}