use bytes::Bytes;
use std::task::{Context, Poll};
#[derive(Clone, PartialEq, Eq)]
pub enum Message {
Binary(Bytes),
Ping,
Pong,
Close,
}
impl std::fmt::Debug for Message {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Binary(data) => f.debug_struct("Binary").field("len", &data.len()).finish(),
Self::Ping => f.debug_struct("Ping").finish(),
Self::Pong => f.debug_struct("Pong").finish(),
Self::Close => f.debug_struct("Close").finish(),
}
}
}
pub trait WebSocket: Send + 'static {
fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>>;
fn start_send_unpin(&mut self, item: Message) -> Result<(), crate::Error>;
fn poll_flush_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>>;
fn poll_close_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>>;
fn poll_next_unpin(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Message, crate::Error>>>;
}
#[cfg(feature = "tungstenite")]
mod tokio_tungstenite {
use super::{Message, WebSocket};
use bytes::Bytes;
use futures_util::{Sink, Stream};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::tungstenite;
use tracing::error;
impl From<tungstenite::Message> for Message {
#[inline]
fn from(msg: tungstenite::Message) -> Self {
match msg {
tungstenite::Message::Binary(data) => Self::Binary(data),
tungstenite::Message::Text(text) => {
error!("Received text message: {text}");
Self::Binary(Bytes::from(text))
}
tungstenite::Message::Ping(_) => Self::Ping,
tungstenite::Message::Pong(_) => Self::Pong,
tungstenite::Message::Close(_) => Self::Close,
tungstenite::Message::Frame(_) => {
unreachable!("`Frame` message should not be received")
}
}
}
}
impl From<Message> for tungstenite::Message {
#[inline]
fn from(msg: Message) -> Self {
match msg {
Message::Binary(data) => Self::Binary(data),
Message::Ping => Self::Ping(Bytes::new()),
Message::Pong => Self::Pong(Bytes::new()),
Message::Close => Self::Close(None),
}
}
}
impl From<tungstenite::Error> for crate::Error {
#[inline]
fn from(e: tungstenite::Error) -> Self {
Self::WebSocket(Box::new(e))
}
}
impl<RW> WebSocket for tokio_tungstenite::WebSocketStream<RW>
where
RW: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
#[inline]
fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>> {
Pin::new(self).poll_ready(cx).map_err(Into::into)
}
#[inline]
fn start_send_unpin(&mut self, item: Message) -> Result<(), crate::Error> {
let item: tungstenite::Message = item.into();
Pin::new(self).start_send(item).map_err(Into::into)
}
#[inline]
fn poll_flush_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>> {
Pin::new(self).poll_flush(cx).map_err(Into::into)
}
#[inline]
fn poll_close_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>> {
let this = Pin::new(self);
futures_util::Sink::poll_close(this, cx).map_err(Into::into)
}
#[inline]
fn poll_next_unpin(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Message, crate::Error>>> {
Pin::new(self)
.poll_next(cx)
.map(|opt| opt.map(|res| res.map(Into::into).map_err(Into::into)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
#[test]
fn test_binary_message() {
let msg = tungstenite::Message::Binary(Bytes::from_static(b"Hello"));
let converted: Message = msg.clone().into();
assert_eq!(converted, Message::Binary(Bytes::from_static(b"Hello")));
assert_eq!(tungstenite::Message::from(converted), msg);
}
#[test]
fn test_text_message() {
let msg = tungstenite::Message::Text("Hello".into());
let converted: Message = msg.into();
assert_eq!(converted, Message::Binary(Bytes::from_static(b"Hello")));
assert_eq!(
tungstenite::Message::from(converted),
tungstenite::Message::Binary(Bytes::from_static(b"Hello"))
);
}
#[test]
fn test_ping_message() {
let msg = tungstenite::Message::Ping(Bytes::from_static(b"Ping"));
let converted: Message = msg.into();
assert_eq!(converted, Message::Ping);
assert_eq!(
tungstenite::Message::from(converted),
tungstenite::Message::Ping(Bytes::new())
);
let msg = tungstenite::Message::Pong(Bytes::from_static(b"Pong"));
let converted: Message = msg.into();
assert_eq!(converted, Message::Pong);
assert_eq!(
tungstenite::Message::from(converted),
tungstenite::Message::Pong(Bytes::new())
);
}
#[test]
fn test_close_message() {
let close_msg =
tungstenite::Message::Close(Some(tungstenite::protocol::frame::CloseFrame {
code: CloseCode::Reserved(1000),
reason: "Normal".into(),
}));
let converted: Message = close_msg.into();
assert_eq!(converted, Message::Close);
}
}
}