use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures_util::{Sink, Stream};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_tungstenite::{
WebSocketStream,
tungstenite::{Error as WsError, Message},
};
/// Adapter that exposes a [`WebSocketStream`] through tokio's [`AsyncRead`] and
/// [`AsyncWrite`] traits.
///
/// Outgoing bytes are sent as binary WebSocket messages; incoming binary and text message
/// payloads are surfaced as a byte stream.
#[derive(Debug)]
pub struct WsIo<S> {
    /// The underlying WebSocket connection.
    ws: WebSocketStream<S>,
    /// Remaining bytes of the most recently received message that did not fit into the
    /// caller's read buffer; they are handed out before the next message is read.
    read_buf: Bytes,
}
impl<S> WsIo<S> {
    /// Wraps an established WebSocket stream so it can be used as a plain byte stream.
    ///
    /// The adapter starts with no buffered read data.
    pub fn new(ws: WebSocketStream<S>) -> Self {
        let read_buf = Bytes::new();
        Self { ws, read_buf }
    }

    /// Consumes the adapter and returns the underlying WebSocket stream.
    ///
    /// Any bytes of a partially consumed message that have not yet been returned by a read
    /// are discarded.
    pub fn into_inner(self) -> WebSocketStream<S> {
        let Self { ws, .. } = self;
        ws
    }
}
impl<S> AsyncRead for WsIo<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if !self.read_buf.is_empty() {
let to_copy = self.read_buf.len().min(buf.remaining());
buf.put_slice(&self.read_buf[..to_copy]);
self.read_buf = self.read_buf.slice(to_copy..);
return Poll::Ready(Ok(()));
}
loop {
match Pin::new(&mut self.ws).poll_next(cx) {
Poll::Ready(Some(Ok(msg))) => match msg {
Message::Binary(data) => {
self.read_buf = data;
let to_copy = self.read_buf.len().min(buf.remaining());
buf.put_slice(&self.read_buf[..to_copy]);
self.read_buf = self.read_buf.slice(to_copy..);
return Poll::Ready(Ok(()));
}
Message::Text(text) => {
self.read_buf = Bytes::from(text.as_bytes().to_vec());
let to_copy = self.read_buf.len().min(buf.remaining());
buf.put_slice(&self.read_buf[..to_copy]);
self.read_buf = self.read_buf.slice(to_copy..);
return Poll::Ready(Ok(()));
}
Message::Ping(payload) => {
let mut ws = Pin::new(&mut self.ws);
match ws.as_mut().poll_ready(cx) {
Poll::Ready(Ok(())) => {
if let Err(err) = ws.start_send(Message::Pong(payload)) {
return Poll::Ready(Err(ws_err(err)));
}
continue;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(ws_err(err))),
Poll::Pending => return Poll::Pending,
}
}
Message::Pong(_) => continue,
Message::Close(_) => return Poll::Ready(Ok(())),
Message::Frame(_) => continue,
},
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(ws_err(err))),
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Pending => return Poll::Pending,
}
}
}
}
impl<S> AsyncWrite for WsIo<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Queues `data` as a single binary WebSocket message and reports all of it as written.
    ///
    /// The message is handed to the sink via `start_send`; it is only guaranteed to be on the
    /// wire once `poll_flush` (or `poll_shutdown`) completes. An empty `data` slice is a no-op
    /// that returns `Ok(0)` without touching the socket.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        data: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        if data.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let mut sink = Pin::new(&mut self.ws);
        let readiness = sink.as_mut().poll_ready(cx);
        let sent = match readiness {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(ready) => ready.and_then(|()| {
                sink.start_send(Message::Binary(Bytes::copy_from_slice(data)))
            }),
        };
        Poll::Ready(sent.map(|()| data.len()).map_err(ws_err))
    }

    /// Flushes any queued WebSocket frames to the underlying transport.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.ws).poll_flush(cx).map_err(ws_err)
    }

    /// Closes the WebSocket sink via `Sink::poll_close`.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.ws).poll_close(cx).map_err(ws_err)
    }
}
fn ws_err(err: WsError) -> std::io::Error {
std::io::Error::other(err)
}
/// Unit tests for the WebSocket byte-stream adapter.
#[cfg(test)]
mod tests {
    // NOTE(review): no tests yet. Worth covering: partial reads that span multiple calls,
    // empty data messages, close-frame EOF, and error conversion through `ws_err`.
}