use std::fmt::Debug;
use std::fmt::Formatter;
use std::fmt::Result as FmtResult;
use std::io;
use std::marker::PhantomData;
use std::pin::Pin;
use std::str::from_utf8 as str_from_utf8;
use std::task::Poll as StdPoll;
use std::time::Duration;
use futures::task::Context;
use futures::task::Poll;
use futures::Sink;
use futures::SinkExt as _;
use futures::Stream;
use futures::StreamExt as _;
use tokio::time::interval;
use tokio::time::Interval;
use tokio_tungstenite::tungstenite::Error as WebSocketError;
use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
use tracing::debug;
use tracing::error;
use tracing::field::debug;
use tracing::field::DebugValue;
use tracing::trace;
/// The state of the keep-alive ping logic, evaluated on every tick of
/// the ping interval.
///
/// The state starts out as `NotNeeded`, is advanced by
/// `Pinger::advance` whenever the interval ticks, and is reset to
/// `NotNeeded` by `Pinger::activity`.
#[derive(Clone, Copy, Debug)]
enum Ping {
/// Activity was reported since the last tick (or no tick happened
/// yet); the next tick does not send a ping.
NotNeeded,
/// No activity was reported during the last interval; the next tick
/// sends a ping.
Needed,
/// A ping was sent; if the state is still `Pending` on the next tick
/// the peer is considered unresponsive and an error is raised.
Pending,
}
/// A payload-carrying message as yielded by the `Stream` implementation
/// and accepted by the `Sink` implementation of `Wrapper`.
///
/// Control messages (ping, pong, close, raw frames) are consumed by the
/// wrapper itself and have no representation here.
#[derive(Debug, PartialEq)]
pub enum Message {
/// A text message.
Text(String),
/// A binary message.
Binary(Vec<u8>),
}
impl From<Message> for WebSocketMessage {
    /// Convert a [`Message`] into the corresponding WebSocket message
    /// variant, moving the payload over.
    fn from(message: Message) -> Self {
        match message {
            Message::Binary(bytes) => Self::Binary(bytes),
            Message::Text(text) => Self::Text(text),
        }
    }
}
/// State machine for sending a single out-of-band message (such as a
/// ping or a pong) over a sink without blocking.
#[derive(Debug)]
enum SendMessageState<M> {
/// No message is waiting to be sent or flushed.
Unused,
/// A message is waiting for the sink to become ready. The `Option`
/// allows the message to be moved out while the state is borrowed.
Pending(Option<M>),
/// A message was handed to the sink and still has to be flushed.
Flush,
}
impl<M> SendMessageState<M> {
    /// Drive the queued message (if any) one step forward on `sink`.
    ///
    /// * `Unused`: nothing to do.
    /// * `Pending`: if the sink reports readiness the message is handed
    ///   to `start_send` and the state becomes `Flush`; if the sink is
    ///   not ready yet the state is left untouched so that the attempt is
    ///   repeated on the next invocation.
    /// * `Flush`: the sink is flushed. The state only returns to `Unused`
    ///   once the flush has resolved (successfully or with an error). A
    ///   flush that is still pending keeps the state at `Flush` so that
    ///   it is polled again on the next invocation instead of being
    ///   abandoned half way.
    ///
    /// A `Poll::Pending` from the sink is not an error: the method
    /// returns `Ok(())` and expects to be invoked again; `ctx` is passed
    /// on to the sink's poll methods.
    ///
    /// # Errors
    ///
    /// Any error reported by the sink is returned. The state is reset to
    /// `Unused` in that case, i.e., the message is dropped.
    fn advance<S>(&mut self, sink: &mut S, ctx: &mut Context<'_>) -> Result<(), S::Error>
    where
        S: Sink<M> + Unpin,
        M: Debug,
    {
        match self {
            Self::Unused => Ok(()),
            Self::Pending(message) => {
                match sink.poll_ready_unpin(ctx) {
                    // Not ready yet; keep the message around and retry
                    // on the next invocation.
                    Poll::Pending => return Ok(()),
                    Poll::Ready(Ok(())) => (),
                    Poll::Ready(Err(err)) => {
                        *self = Self::Unused;
                        return Err(err)
                    },
                }

                let message = message.take();
                // Reset the state up front so that an error from
                // `start_send` below leaves us in `Unused`.
                *self = Self::Unused;
                debug!(
                    channel = debug(sink as *const _),
                    send_msg = debug(&message)
                );

                if let Some(message) = message {
                    sink.start_send_unpin(message)?;
                    *self = Self::Flush;
                }
                Ok(())
            },
            Self::Flush => {
                trace!(channel = debug(sink as *const _), msg = "flushing");
                match sink.poll_flush_unpin(ctx) {
                    // The flush has not completed. Stay in `Flush` so
                    // that we continue driving it next time around.
                    Poll::Pending => Ok(()),
                    Poll::Ready(result) => {
                        *self = Self::Unused;
                        result
                    },
                }
            },
        }
    }

    /// Queue `message` for sending, replacing any previous state.
    fn set(&mut self, message: M) {
        *self = Self::Pending(Some(message))
    }
}
/// Queue `message` in `message_state`, overwriting whatever was queued
/// before.
///
/// If the previous message has not been sent or flushed yet, the overrun
/// is reported at debug level. `channel` is only used to tag the log
/// output with its address.
fn set_message<S, M>(channel: &S, message_state: &mut SendMessageState<M>, message: M)
where
    M: Debug,
{
    let channel_ptr = channel as *const S;

    if let SendMessageState::Pending(previous) = &*message_state {
        debug!(
            channel = debug(channel_ptr),
            send_msg_old = debug(&previous),
            send_msg_new = debug(&message),
            msg = "message overrun; last message has not been sent"
        );
    } else if let SendMessageState::Flush = &*message_state {
        debug!(
            channel = debug(channel_ptr),
            msg = "message overrun; last message has not been flushed"
        );
    }

    message_state.set(message);
}
/// A borrowed WebSocket message with a custom `Debug` representation
/// used for logging.
struct DebugMessage<'m> {
/// The message to format.
message: &'m WebSocketMessage,
}
impl<'m> Debug for DebugMessage<'m> {
    /// Format the message, rendering binary payloads that are valid
    /// UTF-8 as a string and deferring to the message's own `Debug`
    /// implementation in every other case.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        if let WebSocketMessage::Binary(data) = self.message {
            if let Ok(text) = str_from_utf8(data) {
                return f.debug_tuple("Binary").field(&text).finish();
            }
        }
        Debug::fmt(self.message, f)
    }
}
/// Wrap `message` such that it can be recorded as a `tracing` field
/// using the `DebugMessage` formatting.
fn debug_message(message: &WebSocketMessage) -> DebugValue<DebugMessage<'_>> {
    let wrapped = DebugMessage { message };
    debug(wrapped)
}
/// Logic for periodically sending pings to the peer and for detecting a
/// peer that stopped responding.
#[derive(Debug)]
struct Pinger {
/// Send state of the ping message currently in flight, if any.
ping: SendMessageState<WebSocketMessage>,
/// The interval whose ticks trigger the ping logic.
next_ping: Interval,
/// What to do on the next tick of `next_ping`.
ping_state: Ping,
}
impl Pinger {
fn new(ping_interval: Duration) -> Self {
Self {
ping: SendMessageState::Unused,
next_ping: interval(ping_interval),
ping_state: Ping::NotNeeded,
}
}
#[allow(clippy::result_large_err)]
fn advance<S>(&mut self, sink: &mut S, ctx: &mut Context<'_>) -> Result<(), S::Error>
where
S: Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
{
self.ping.advance(sink, ctx)?;
match self.next_ping.poll_tick(ctx) {
StdPoll::Ready(_) => {
self.ping_state = match self.ping_state {
Ping::NotNeeded => {
trace!(
channel = debug(sink as *const _),
msg = "skipping ping due to activity"
);
Ping::Needed
},
Ping::Needed => {
trace!(channel = debug(sink as *const _), msg = "sending ping");
let message = WebSocketMessage::Ping(Vec::new());
set_message(sink, &mut self.ping, message);
self.ping.advance(sink, ctx)?;
Ping::Pending
},
Ping::Pending => {
error!(
channel = debug(sink as *const _),
msg = "server failed to respond to pings"
);
self.ping_state = Ping::Needed;
let err = WebSocketError::Io(io::Error::new(
io::ErrorKind::Other,
"server failed to respond to pings",
));
return Err(err)
},
};
Ok(())
},
StdPoll::Pending => Ok(()),
}
}
fn activity(&mut self) {
self.ping_state = Ping::NotNeeded;
}
}
/// A builder for `Wrapper` objects.
#[derive(Debug)]
pub struct Builder<S> {
/// The interval driving the ping logic; `None` disables pings.
ping_interval: Option<Duration>,
/// Whether the wrapper replies to received pings with a pong itself.
send_pongs: bool,
/// Ties the builder to the channel type `S` without storing a value.
_phantom: PhantomData<S>,
}
impl<S> Builder<S> {
    /// Set the interval at which the ping logic runs. `None` disables
    /// the sending of pings altogether.
    pub fn set_ping_interval(self, interval: Option<Duration>) -> Builder<S> {
        Self {
            ping_interval: interval,
            ..self
        }
    }

    /// Set whether the wrapper replies to incoming pings with a pong of
    /// its own.
    pub fn set_send_pongs(self, enable: bool) -> Builder<S> {
        Self {
            send_pongs: enable,
            ..self
        }
    }

    /// Consume the builder and wrap `channel` using the configured
    /// settings.
    pub fn build(self, channel: S) -> Wrapper<S> {
        let pong = self.send_pongs.then_some(SendMessageState::Unused);
        let ping = self.ping_interval.map(Pinger::new);

        Wrapper {
            inner: channel,
            pong,
            ping,
        }
    }
}
impl<S> Default for Builder<S> {
    /// Create a builder with a ping interval of 30 seconds and explicit
    /// pong replies disabled.
    fn default() -> Self {
        let ping_interval = Some(Duration::from_secs(30));
        Self {
            ping_interval,
            send_pongs: false,
            _phantom: PhantomData,
        }
    }
}
/// A wrapper around a WebSocket channel that filters out control
/// messages, optionally replies to pings, and optionally sends pings of
/// its own to detect an unresponsive peer.
#[derive(Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct Wrapper<S> {
/// The wrapped channel.
inner: S,
/// Send state for pong replies; `None` if the wrapper does not send
/// pongs itself.
pong: Option<SendMessageState<WebSocketMessage>>,
/// The ping logic; `None` if pings are disabled.
ping: Option<Pinger>,
}
impl<S> Wrapper<S> {
pub fn builder() -> Builder<S> {
Builder::default()
}
}
impl<S> Stream for Wrapper<S>
where
    S: Sink<WebSocketMessage, Error = WebSocketError>
        + Stream<Item = Result<WebSocketMessage, WebSocketError>>
        + Unpin,
{
    type Item = Result<Message, WebSocketError>;

    /// Poll the wrapped channel for the next text or binary message.
    ///
    /// Before reading, a pong or ping that is in flight is driven
    /// forward. Every received message counts as activity for the ping
    /// logic. Text and binary messages are yielded; pings are answered
    /// with a pong if that is enabled; pong, close, and raw frame
    /// messages are swallowed.
    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        if let Some(pong) = this.pong.as_mut() {
            if let Err(err) = pong.advance(&mut this.inner, ctx) {
                return Poll::Ready(Some(Err(err)));
            }
        }

        if let Some(ping) = this.ping.as_mut() {
            if let Err(err) = ping.advance(&mut this.inner, ctx) {
                return Poll::Ready(Some(Err(err)));
            }
        }

        loop {
            let message = match this.inner.poll_next_unpin(ctx) {
                Poll::Ready(Some(Ok(message))) => message,
                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            };

            debug!(
                channel = debug(&this.inner as *const _),
                recv_msg = debug_message(&message)
            );

            // Any message from the peer shows that it is alive.
            if let Some(ping) = this.ping.as_mut() {
                ping.activity();
            }

            match message {
                WebSocketMessage::Text(data) => {
                    return Poll::Ready(Some(Ok(Message::Text(data))));
                },
                WebSocketMessage::Binary(data) => {
                    return Poll::Ready(Some(Ok(Message::Binary(data))));
                },
                WebSocketMessage::Ping(data) => {
                    if let Some(pong) = this.pong.as_mut() {
                        set_message(&this.inner, pong, WebSocketMessage::Pong(data));
                        if let Err(err) = pong.advance(&mut this.inner, ctx) {
                            return Poll::Ready(Some(Err(err)));
                        }
                    }
                },
                // Nothing to report to the user; poll for the next
                // message.
                WebSocketMessage::Pong(_)
                | WebSocketMessage::Close(_)
                | WebSocketMessage::Frame(_) => (),
            }
        }
    }
}
impl<S> Sink<Message> for Wrapper<S>
where
S: Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
{
type Error = WebSocketError;
fn poll_ready(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready_unpin(ctx)
}
fn start_send(mut self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> {
let message = message.into();
debug!(
channel = debug(&self.inner as *const _),
send_msg = debug_message(&message)
);
self.inner.start_send_unpin(message)
}
fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
trace!(channel = debug(&self.inner as *const _), msg = "flushing");
self.inner.poll_flush_unpin(ctx)
}
fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_close_unpin(ctx)
}
}
// Tests for the wrapper, run against a mock WebSocket server.
#[cfg(test)]
mod tests {
use super::*;
use std::future::Future;
use futures::future::ready;
use futures::TryStreamExt as _;
use rand::seq::IteratorRandom as _;
use rand::thread_rng;
use rand::Rng as _;
use test_log::test;
use tokio::time::sleep;
use tokio::time::timeout;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::error::ProtocolError;
use url::Url;
use crate::test::mock_server;
use crate::test::WebSocketStream;
/// Check that `debug_message` renders binary payloads that are valid
/// UTF-8 as a string and falls back to the regular `Debug` output
/// otherwise.
#[test]
fn debug_websocket_message() {
let message = WebSocketMessage::Binary(b"this is a test".to_vec());
let expected = r#"Binary("this is a test")"#;
assert_eq!(format!("{:?}", debug_message(&message)), expected);
// A truncated multi-byte sequence, i.e., not valid UTF-8.
let message = WebSocketMessage::Binary([0xf0, 0x90, 0x80].to_vec());
let expected = r#"Binary([240, 144, 128])"#;
assert_eq!(format!("{:?}", debug_message(&message)), expected);
let message = WebSocketMessage::Ping(Vec::new());
let expected = r#"Ping([])"#;
assert_eq!(format!("{:?}", debug_message(&message)), expected);
}
/// Start a mock server running `f`, connect to it, and wrap the
/// resulting client stream using `builder`.
async fn serve_and_connect_with_builder<F, R>(
builder: Builder<WebSocketStream>,
f: F,
) -> Wrapper<WebSocketStream>
where
F: FnOnce(WebSocketStream) -> R + Send + Sync + 'static,
R: Future<Output = Result<(), WebSocketError>> + Send + Sync + 'static,
{
let addr = mock_server(f).await;
let url = Url::parse(&format!("ws://{}", addr)).unwrap();
let (stream, _) = connect_async(url).await.unwrap();
builder.build(stream)
}
/// Like `serve_and_connect_with_builder`, but using a builder with a
/// ping interval of 10ms and otherwise default settings.
async fn serve_and_connect<F, R>(f: F) -> Wrapper<WebSocketStream>
where
F: FnOnce(WebSocketStream) -> R + Send + Sync + 'static,
R: Future<Output = Result<(), WebSocketError>> + Send + Sync + 'static,
{
let ping = Some(Duration::from_millis(10));
let builder = Wrapper::builder().set_ping_interval(ping);
serve_and_connect_with_builder(builder, f).await
}
/// Check that the wrapped stream reports a reset without closing
/// handshake if the server handler returns without sending anything.
#[test(tokio::test)]
async fn no_messages() {
async fn test(_stream: WebSocketStream) -> Result<(), WebSocketError> {
Ok(())
}
let err = serve_and_connect(test)
.await
.try_for_each(|_| ready(Ok(())))
.await
.unwrap_err();
match err {
WebSocketError::Protocol(e) if e == ProtocolError::ResetWithoutClosingHandshake => (),
e => panic!("received unexpected error: {}", e),
}
}
/// Check that the wrapped stream ends without error if the server
/// only sends a close message.
#[test(tokio::test)]
async fn direct_close() {
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
stream.send(WebSocketMessage::Close(None)).await?;
Ok(())
}
serve_and_connect(test)
.await
.try_for_each(|_| ready(Ok(())))
.await
.unwrap();
}
/// Check that, with pings disabled and the wrapper not sending pongs
/// itself, a server ping is answered by exactly one pong.
// NOTE(review): the pong presumably originates from the underlying
// WebSocket implementation rather than the wrapper — confirm.
#[test(tokio::test)]
async fn ping_pong() {
async fn test(stream: WebSocketStream) -> Result<(), WebSocketError> {
let mut stream = stream.fuse();
stream.send(WebSocketMessage::Ping(Vec::new())).await?;
assert_eq!(
stream.next().await.unwrap()?,
WebSocketMessage::Pong(Vec::new()),
);
// No further message is expected within 20ms.
let future = stream.select_next_some();
assert!(timeout(Duration::from_millis(20), future).await.is_err());
stream.send(WebSocketMessage::Close(None)).await?;
Ok(())
}
let builder = Wrapper::builder().set_ping_interval(None);
serve_and_connect_with_builder(builder, test)
.await
.try_for_each(|_| ready(Ok(())))
.await
.unwrap();
}
/// Check that, with the wrapper sending pongs itself, the server
/// receives two pongs in response to a single ping.
#[test(tokio::test)]
async fn ping_pong_2() {
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
stream.send(WebSocketMessage::Ping(Vec::new())).await?;
assert_eq!(
stream.next().await.unwrap()?,
WebSocketMessage::Pong(Vec::new()),
);
assert_eq!(
stream.next().await.unwrap()?,
WebSocketMessage::Pong(Vec::new()),
);
stream.send(WebSocketMessage::Close(None)).await?;
Ok(())
}
let builder = Wrapper::builder().set_send_pongs(true);
serve_and_connect_with_builder(builder, test)
.await
.try_for_each(|_| ready(Ok(())))
.await
.unwrap();
}
/// Check that the server receives pings when a ping interval is
/// configured.
#[test(tokio::test)]
async fn pings_are_sent() {
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
for _ in 0..2 {
assert!(matches!(
stream.next().await.unwrap()?,
WebSocketMessage::Ping(_)
));
}
stream.send(WebSocketMessage::Close(None)).await?;
Ok(())
}
serve_and_connect(test)
.await
.try_for_each(|_| ready(Ok(())))
.await
.unwrap();
}
/// Check that the server receives nothing within 20ms when pings are
/// disabled.
#[test(tokio::test)]
async fn no_pings_are_sent_when_disabled() {
async fn test(stream: WebSocketStream) -> Result<(), WebSocketError> {
let mut stream = stream.fuse();
let future = stream.select_next_some();
assert!(timeout(Duration::from_millis(20), future).await.is_err());
stream.send(WebSocketMessage::Close(None)).await?;
Ok(())
}
let builder = Wrapper::builder().set_ping_interval(None);
serve_and_connect_with_builder(builder, test)
.await
.try_for_each(|_| ready(Ok(())))
.await
.unwrap();
}
/// Check that the wrapped stream reports an error if the server does
/// not react to pings.
#[test(tokio::test)]
async fn no_pong_response() {
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
stream
.send(WebSocketMessage::Text("test".to_string()))
.await?;
// Do not touch the stream for a while, leaving the client's pings
// unanswered.
sleep(Duration::from_secs(10)).await;
Ok(())
}
let stream = serve_and_connect(test).await;
let err = stream.try_for_each(|_| ready(Ok(()))).await.unwrap_err();
assert_eq!(
err.to_string(),
"IO error: server failed to respond to pings"
);
}
/// Check that text messages are passed through in order while a pong
/// sent in between is filtered out.
#[test(tokio::test)]
async fn send_messages() {
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
stream
.send(WebSocketMessage::Text("42".to_string()))
.await?;
stream.send(WebSocketMessage::Pong(Vec::new())).await?;
stream
.send(WebSocketMessage::Text("43".to_string()))
.await?;
stream.send(WebSocketMessage::Close(None)).await?;
Ok(())
}
let stream = serve_and_connect(test).await;
let messages = stream.try_collect::<Vec<_>>().await.unwrap();
assert_eq!(
messages,
vec![
Message::Text("42".to_string()),
Message::Text("43".to_string())
]
);
}
/// Stress test sending 50000 random pong, text, and binary messages
/// through the wrapper and expecting the stream to end without error.
#[test(tokio::test)]
#[ignore = "stress test; test takes a long time"]
async fn stress_stream() {
async fn test(mut stream: WebSocketStream) -> Result<(), WebSocketError> {
// Create a buffer of random bytes with a random length below 32.
fn random_buf() -> Vec<u8> {
let len = (0..32).choose(&mut thread_rng()).unwrap();
let mut vec = Vec::new();
vec.extend((0..len).map(|_| thread_rng().gen::<u8>()));
vec
}
for _ in 0..50000 {
// 0 selects a pong, other even values a text message, and odd
// values a binary message.
let message = match (0..5).choose(&mut thread_rng()).unwrap() {
0 => WebSocketMessage::Pong(random_buf()),
i => {
if i & 0x1 == 0 {
let len = (0..32).choose(&mut thread_rng()).unwrap();
let mut string = String::new();
string.extend((0..len).map(|_| thread_rng().gen::<char>()));
WebSocketMessage::Text(string)
} else {
WebSocketMessage::Binary(random_buf())
}
},
};
stream.send(message).await?;
}
stream.send(WebSocketMessage::Close(None)).await?;
Ok(())
}
serve_and_connect(test)
.await
.try_for_each(|_| ready(Ok(())))
.await
.unwrap();
}
}