use std::future::ready;
use std::io::ErrorKind;
use bytes::Bytes;
use futures::Sink;
use futures::SinkExt as _;
use futures::Stream;
use futures::StreamExt as _;
use nameth::nameth;
use tokio::sync::oneshot;
use tokio_util::io::CopyToBytes;
use tokio_util::io::SinkWriter;
use tokio_util::io::StreamReader;
/// Adapter between a message-oriented WebSocket and byte-oriented tokio I/O.
///
/// Implementors describe how to convert between their message type and raw
/// [`Bytes`]. [`WebSocketIo::to_async_io`] then uses those conversions to expose
/// a WebSocket as a single [`tokio::io::AsyncRead`] + [`tokio::io::AsyncWrite`] object.
pub trait WebSocketIo {
/// The message type exchanged with the underlying WebSocket stream/sink.
type Message;
/// The error type produced by the underlying WebSocket stream and sink.
type Error: std::error::Error + Send + Sync + 'static;
/// Extracts the payload bytes of an incoming message.
///
/// Applied to every successfully read item of the WebSocket stream.
fn into_data(message: Self::Message) -> Bytes;
/// Wraps outgoing bytes into a message.
///
/// Applied to every chunk written through the I/O object returned by
/// [`WebSocketIo::to_async_io`].
///
/// NOTE(review): the name is misspelled ("messsge"). It is kept as-is because
/// renaming a required trait method would break every existing implementor.
fn into_messsge(bytes: Bytes) -> Self::Message;
/// Converts a WebSocket into a byte-oriented I/O object plus an end-of-stream future.
///
/// Returns `(io, eos)`:
/// - `io` reads the payloads of incoming messages (via [`Self::into_data`]). It writes
///   each chunk it is given as an outgoing message (via [`Self::into_messsge`]).
/// - `eos` resolves to `Err` carrying the first error the WebSocket stream produced
///   while reading. The read side of `io` itself only observes a bare
///   [`std::io::ErrorKind::ConnectionAborted`] for such errors.
/// - `eos` resolves to `Ok(())` once `io` is dropped without such an error having
///   occurred.
/// - Write errors are returned by `io`'s write operations (kind `ConnectionAborted`)
///   and are not reported on `eos`.
fn to_async_io(
web_socket: impl Stream<Item = Result<Self::Message, Self::Error>>
+ Sink<Self::Message, Error = Self::Error>,
) -> (
impl tokio::io::AsyncRead + tokio::io::AsyncWrite,
impl Future<Output = std::io::Result<()>>,
)
where
Self: Sized,
{
// All the adaptation logic lives in a free function shared by every implementor.
to_async_io_impl::<Self>(web_socket)
}
}
/// Wraps an error produced by the WebSocket stream while reading.
///
/// `to_async_io_impl` puts it inside a `std::io::Error` of kind `ConnectionAborted`.
/// It is delivered through the end-of-stream future rather than to the reader.
///
/// NOTE(review): `READ_ERROR` in the display message is presumably the name
/// constant generated by `#[nameth]` for this type — confirm.
#[nameth]
#[derive(thiserror::Error, Debug)]
#[error("[{READ_ERROR}] {0}")]
struct ReadError<E>(E);
/// Wraps an error produced by the WebSocket sink while writing.
///
/// `to_async_io_impl` puts it inside a `std::io::Error` of kind `ConnectionAborted`.
/// That error is returned directly from the writer side of the I/O object.
///
/// NOTE(review): `WRITE_ERROR` in the display message is presumably the name
/// constant generated by `#[nameth]` for this type — confirm.
#[nameth]
#[derive(thiserror::Error, Debug)]
#[error("[{WRITE_ERROR}] {0}")]
struct WriteError<E>(E);
/// Shared implementation behind [`WebSocketIo::to_async_io`].
///
/// Splits `web_socket` into its sink and stream halves. It adapts them into a single
/// [`tokio::io::AsyncRead`] + [`tokio::io::AsyncWrite`] object and returns it together
/// with an end-of-stream future.
///
/// - **Reads:** every incoming message is converted with [`WebSocketIo::into_data`].
///   When the stream yields an error, the reader only sees a bare
///   [`ErrorKind::ConnectionAborted`]. The full error (a [`ReadError`] inside a
///   `ConnectionAborted` [`std::io::Error`]) is delivered through the end-of-stream
///   future instead. Only the first read error is delivered.
/// - **Writes:** every written chunk is copied into [`Bytes`] and wrapped with
///   [`WebSocketIo::into_messsge`]. Sink errors are returned directly from the writer
///   as a `ConnectionAborted` [`std::io::Error`] wrapping a [`WriteError`]. They are
///   not reported on the end-of-stream future.
///
/// The end-of-stream future resolves to `Err` with the first read error. Otherwise it
/// resolves to `Ok(())` once the reader (and therefore the returned I/O object) has
/// been dropped.
fn to_async_io_impl<IO: WebSocketIo>(
    web_socket: impl Stream<Item = Result<IO::Message, IO::Error>>
    + Sink<IO::Message, Error = IO::Error>,
) -> (
    impl tokio::io::AsyncRead + tokio::io::AsyncWrite,
    impl Future<Output = std::io::Result<()>>,
) {
    let (error_tx, error_rx) = oneshot::channel();
    // A oneshot sender can be used once: after the first read error it becomes `None`,
    // and later errors are only surfaced to the reader as `ConnectionAborted`.
    let mut error_tx = Some(error_tx);
    let (sink, stream) = web_socket.split();

    let reader = StreamReader::new(stream.map(move |message| {
        message.map(IO::into_data).map_err(|error: IO::Error| {
            if let Some(tx) = error_tx.take() {
                let error = std::io::Error::new(ErrorKind::ConnectionAborted, ReadError(error));
                // `send` only fails when the end-of-stream future was dropped, i.e. nobody
                // is interested in the error any more; ignoring that is intended.
                let _ = tx.send(error);
            }
            ErrorKind::ConnectionAborted
        })
    }));

    let sink = CopyToBytes::new(sink.with(|data| ready(Ok(IO::into_messsge(data)))))
        .sink_map_err(|error: IO::Error| {
            std::io::Error::new(ErrorKind::ConnectionAborted, WriteError(error))
        });
    let writer = SinkWriter::new(sink);

    // Boxed so that the future is `Unpin`. `Unpin` is an auto trait and leaks through
    // `impl Future`, so callers may already rely on it (e.g. polling `&mut eos`).
    let eos = Box::pin(async move {
        match error_rx.await {
            Ok(error) => Err(error),
            // The sender was dropped without sending: the reader went away cleanly.
            Err(oneshot::error::RecvError { .. }) => Ok(()),
        }
    });

    (tokio::io::join(reader, writer), eos)
}