use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use futures_util::Sink;
use futures_util::stream::Stream;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Message;
/// Byte-stream adapter over a [`WebSocketStream`].
///
/// Presents the message-oriented WebSocket connection through tokio's
/// `AsyncRead` / `AsyncWrite` traits: payloads of incoming binary messages are
/// handed out as a contiguous byte stream, and written bytes are collected
/// and sent as one binary message per flush.
pub struct WsStream<S> {
    /// The wrapped WebSocket connection.
    inner: WebSocketStream<S>,
    /// Tail of a received binary payload that did not fit into the caller's
    /// read buffer; it is served first on the next read.
    read_buf: BytesMut,
    /// Bytes accepted by `poll_write` that have not yet been sent as a
    /// WebSocket message.
    write_buf: BytesMut,
    /// Set once a Close frame or the end of the message stream has been seen;
    /// reads report EOF from then on.
    read_closed: bool,
}
impl<S> WsStream<S> {
pub fn new(inner: WebSocketStream<S>) -> Self {
Self {
inner,
read_buf: BytesMut::new(),
write_buf: BytesMut::new(),
read_closed: false,
}
}
}
/// Reads the payloads of incoming binary WebSocket messages as a byte stream.
impl<S> AsyncRead for WsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Fills `buf` with the next available payload bytes.
    ///
    /// Bytes left over from a previously received message are served first.
    /// Otherwise the next WebSocket message is polled:
    ///
    /// * a non-empty binary message is copied into `buf`; whatever does not
    ///   fit is kept for the next call,
    /// * empty binary messages and Ping / Pong / raw frames carry no stream
    ///   data and are skipped,
    /// * a Close frame or the end of the message stream marks the read side
    ///   closed; this and every later call report EOF (nothing filled),
    /// * a text message is rejected.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when a text message is received and
    /// `ConnectionAborted` when the underlying WebSocket stream fails.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();

        // Drain what is left of an earlier message before touching the socket.
        if !this.read_buf.is_empty() {
            let to_copy = this.read_buf.len().min(buf.remaining());
            buf.put_slice(&this.read_buf.split_to(to_copy));
            return Poll::Ready(Ok(()));
        }
        if this.read_closed {
            return Poll::Ready(Ok(()));
        }

        // Keep polling until a message yields payload bytes, the stream ends,
        // fails, or has nothing ready. Messages without stream data must not
        // complete the read: returning `Ready(Ok(()))` with nothing filled is
        // how `AsyncRead` signals EOF.
        loop {
            let msg = match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Ready(Some(Ok(msg))) => msg,
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionAborted, e)));
                }
                Poll::Ready(None) => {
                    this.read_closed = true;
                    return Poll::Ready(Ok(()));
                }
                // The inner stream has registered the waker for us.
                Poll::Pending => return Poll::Pending,
            };

            match msg {
                Message::Binary(data) => {
                    if data.is_empty() {
                        // An empty message would look like EOF to the caller.
                        continue;
                    }
                    let to_copy = data.len().min(buf.remaining());
                    buf.put_slice(&data[..to_copy]);
                    if to_copy < data.len() {
                        this.read_buf.extend_from_slice(&data[to_copy..]);
                    }
                    return Poll::Ready(Ok(()));
                }
                Message::Close(_) => {
                    this.read_closed = true;
                    return Poll::Ready(Ok(()));
                }
                Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
                Message::Text(_) => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "Bolt requires binary WebSocket frames, received text",
                    )));
                }
            }
        }
    }
}
/// Private helpers shared by the `AsyncWrite` methods.
impl<S> WsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Wraps a WebSocket sink error in an `io::Error` of kind `BrokenPipe`.
    fn broken_pipe<E>(err: E) -> io::Error
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        io::Error::new(io::ErrorKind::BrokenPipe, err)
    }

    /// Hands all buffered bytes to the WebSocket sink as one binary message.
    ///
    /// Completes immediately when nothing is buffered. Returns `Pending`
    /// while the sink is not ready to accept a message; the buffer is left
    /// untouched in that case so the call can simply be repeated.
    fn poll_send_buffered(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if self.write_buf.is_empty() {
            return Poll::Ready(Ok(()));
        }
        match Pin::new(&mut self.inner).poll_ready(cx) {
            Poll::Ready(Ok(())) => {}
            Poll::Ready(Err(e)) => return Poll::Ready(Err(Self::broken_pipe(e))),
            Poll::Pending => return Poll::Pending,
        }
        // `split` empties `write_buf`, so the bytes are queued exactly once.
        let data = self.write_buf.split().to_vec();
        match Pin::new(&mut self.inner).start_send(Message::Binary(data.into())) {
            Ok(()) => Poll::Ready(Ok(())),
            Err(e) => Poll::Ready(Err(Self::broken_pipe(e))),
        }
    }
}

/// Writes bytes that are sent as binary WebSocket messages, one per flush.
impl<S> AsyncWrite for WsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Appends `buf` to the internal write buffer.
    ///
    /// Always completes immediately and reports the whole slice as written.
    /// Nothing reaches the peer until `poll_flush` or `poll_shutdown` runs,
    /// and the buffer grows without limit until then.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.write_buf.extend_from_slice(buf);
        Poll::Ready(Ok(buf.len()))
    }

    /// Sends everything buffered so far as a single binary message and
    /// flushes the WebSocket sink.
    ///
    /// # Errors
    ///
    /// Returns `BrokenPipe` when the sink rejects the message or fails to
    /// flush.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match this.poll_send_buffered(cx) {
            Poll::Ready(Ok(())) => {}
            not_done => return not_done,
        }
        Pin::new(&mut this.inner)
            .poll_flush(cx)
            .map_err(Self::broken_pipe)
    }

    /// Sends any still-buffered bytes, then closes the WebSocket connection.
    ///
    /// Shutdown implies a flush, so bytes accepted by `poll_write` are queued
    /// ahead of the close instead of being dropped. The Close frame itself is
    /// left to the sink's `poll_close`; queuing one here as well would repeat
    /// it every time this method is polled again after `Pending`.
    ///
    /// # Errors
    ///
    /// Returns `BrokenPipe` when the buffered data cannot be sent or the
    /// close fails.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match this.poll_send_buffered(cx) {
            Poll::Ready(Ok(())) => {}
            not_done => return not_done,
        }
        Pin::new(&mut this.inner)
            .poll_close(cx)
            .map_err(Self::broken_pipe)
    }
}
#[cfg(test)]
mod tests {
    use futures_util::{SinkExt, StreamExt};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// Builds a connected pair over an in-memory duplex pipe: the client side
    /// wrapped in `WsStream`, the server side as a plain `WebSocketStream`.
    async fn ws_pair() -> (
        WsStream<impl AsyncRead + AsyncWrite + Unpin>,
        WebSocketStream<impl AsyncRead + AsyncWrite + Unpin>,
    ) {
        let (client_io, server_io) = tokio::io::duplex(64 * 1024);

        // Both handshake halves must be driven concurrently to complete.
        let (client_result, server_result) = tokio::join!(
            tokio_tungstenite::client_async("ws://localhost/bolt", client_io),
            tokio_tungstenite::accept_async(server_io),
        );

        let (client_ws, _response) = client_result.expect("client WS handshake");
        let server_ws = server_result.expect("server WS handshake");
        (WsStream::new(client_ws), server_ws)
    }

    /// A binary message's payload is readable as plain bytes.
    #[tokio::test]
    async fn read_binary_message() {
        let (mut client, mut peer) = ws_pair().await;
        let payload: [u8; 3] = [0xAA, 0xBB, 0xCC];

        peer.send(Message::Binary(payload.to_vec().into()))
            .await
            .unwrap();

        let mut received = [0u8; 3];
        client.read_exact(&mut received).await.unwrap();
        assert_eq!(received, payload);
    }

    /// A payload larger than the read buffer is served over several reads.
    #[tokio::test]
    async fn read_buffered_across_calls() {
        let (mut client, mut peer) = ws_pair().await;
        let payload: [u8; 4] = [1, 2, 3, 4];

        peer.send(Message::Binary(payload.to_vec().into()))
            .await
            .unwrap();

        let mut half = [0u8; 2];
        client.read_exact(&mut half).await.unwrap();
        assert_eq!(half, [1, 2]);
        client.read_exact(&mut half).await.unwrap();
        assert_eq!(half, [3, 4]);
    }

    /// Written bytes arrive at the peer as one binary message after a flush.
    #[tokio::test]
    async fn write_and_flush_produces_binary_message() {
        let (mut client, mut peer) = ws_pair().await;

        client.write_all(&[0x01, 0x02, 0x03]).await.unwrap();
        client.flush().await.unwrap();

        let received = peer.next().await.unwrap().unwrap();
        match received {
            Message::Binary(data) => assert_eq!(&data[..], &[0x01, 0x02, 0x03]),
            other => panic!("expected Binary, got {other:?}"),
        }
    }

    /// Text messages are refused with an `InvalidData` error.
    #[tokio::test]
    async fn text_frame_rejected() {
        let (mut client, mut peer) = ws_pair().await;

        peer.send(Message::Text("hello".into())).await.unwrap();

        let mut scratch = [0u8; 5];
        let outcome = client.read_exact(&mut scratch).await;
        assert!(outcome.is_err());

        let failure = outcome.unwrap_err();
        assert_eq!(failure.kind(), io::ErrorKind::InvalidData);
        assert!(failure.to_string().contains("binary"));
    }

    /// A Close frame from the peer shows up as EOF (a zero-length read).
    #[tokio::test]
    async fn close_frame_produces_eof() {
        let (mut client, mut peer) = ws_pair().await;

        peer.send(Message::Close(None)).await.unwrap();

        let mut scratch = [0u8; 1];
        let bytes_read = client.read(&mut scratch).await.unwrap();
        assert_eq!(bytes_read, 0);
    }
}