use std::{
future, io,
pin::Pin,
task::{Context, Poll, Waker},
};
use bytes::Bytes;
use futures_core::Stream as _;
use futures_sink::Sink;
use http::Uri;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_websockets::{ClientBuilder, Message, WebSocketStream};
use watermelon_proto::proto::{
ClientOp, FramedEncoder, ServerOp, decode_frame, error::FrameDecoderError,
};
/// A protocol connection that exchanges client/server ops as binary
/// websocket messages over an underlying transport `S`.
#[derive(Debug)]
pub struct WebsocketConnection<S> {
    /// Websocket stream layered on top of the transport `S`.
    socket: WebSocketStream<S>,
    /// Encoder used to serialize outgoing [`ClientOp`]s into message payloads.
    encoder: FramedEncoder,
    /// Payload of the most recent binary message that has not been fully
    /// decoded yet; ops are decoded from it before the socket is polled again.
    residual_frame: Bytes,
    /// Set to `true` whenever an op is enqueued; exposed via `should_flush()`.
    should_flush: bool,
}
impl<S> WebsocketConnection<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Performs the websocket client handshake for `uri` over an already
    /// established `socket` and wraps the resulting stream.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the handshake fails. Errors that are not
    /// I/O errors are wrapped with [`io::Error::other`].
    pub async fn new(uri: Uri, socket: S) -> io::Result<Self> {
        let (socket, _resp) = ClientBuilder::from_uri(uri)
            .connect_on(socket)
            .await
            .map_err(websockets_error_to_io)?;
        Ok(Self {
            socket,
            encoder: FramedEncoder::new(),
            residual_frame: Bytes::new(),
            should_flush: false,
        })
    }

    /// Polls for the next [`ServerOp`] sent by the server.
    ///
    /// Ops are decoded from the payload of the last binary message first;
    /// the socket is only polled for a new message once that payload is
    /// exhausted. Non-binary messages are ignored.
    ///
    /// # Errors
    ///
    /// - [`WebsocketReadError::Decoder`] if the buffered payload cannot be decoded.
    /// - [`WebsocketReadError::Io`] if the websocket stream yields an error.
    /// - [`WebsocketReadError::Closed`] once the websocket stream has ended.
    pub fn poll_read_next(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<ServerOp, WebsocketReadError>> {
        loop {
            if !self.residual_frame.is_empty() {
                // `decode_frame` takes the buffer by `&mut`; presumably it
                // advances past the decoded op so the remainder is decoded on
                // later calls.
                // NOTE(review): this assumes every binary message contains only
                // complete ops — confirm against the server's framing.
                return Poll::Ready(
                    decode_frame(&mut self.residual_frame).map_err(WebsocketReadError::Decoder),
                );
            }

            match Pin::new(&mut self.socket).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(Ok(message))) if message.is_binary() => {
                    self.residual_frame = message.into_payload().into();
                }
                // Non-binary messages carry no ops; keep polling.
                Poll::Ready(Some(Ok(_message))) => {}
                Poll::Ready(Some(Err(err))) => {
                    return Poll::Ready(Err(WebsocketReadError::Io(websockets_error_to_io(err))));
                }
                Poll::Ready(None) => return Poll::Ready(Err(WebsocketReadError::Closed)),
            }
        }
    }

    /// Waits for the next [`ServerOp`]; async counterpart of [`Self::poll_read_next`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::poll_read_next`].
    pub async fn read_next(&mut self) -> Result<ServerOp, WebsocketReadError> {
        future::poll_fn(|cx| self.poll_read_next(cx)).await
    }

    /// Returns `true` when ops have been enqueued since the last successful flush.
    pub fn should_flush(&self) -> bool {
        self.should_flush
    }

    /// Reports whether the websocket sink can accept another op right now.
    ///
    /// The sink is polled with a no-op waker, so a `false` result does not
    /// schedule a wake-up; drive [`Self::poll_flush`] to make room instead.
    /// A sink that reports an error is treated as not ready. Reporting it as
    /// ready would invite an [`Self::enqueue_write_op`] call that could panic
    /// on a failed send.
    pub fn may_enqueue_more_ops(&mut self) -> bool {
        let mut cx = Context::from_waker(Waker::noop());
        matches!(
            Pin::new(&mut self.socket).poll_ready(&mut cx),
            Poll::Ready(Ok(()))
        )
    }

    /// Encodes `item` and queues it on the websocket sink as one binary message.
    ///
    /// The message is only buffered and is not written yet. Call
    /// [`Self::flush`] or [`Self::poll_flush`] to send it. Afterwards
    /// [`Self::should_flush`] returns `true`.
    ///
    /// # Panics
    ///
    /// Panics if the websocket sink rejects the message. Check
    /// [`Self::may_enqueue_more_ops`] before calling this.
    pub fn enqueue_write_op(&mut self, item: &ClientOp) {
        let payload = self.encoder.encode(item);
        Pin::new(&mut self.socket)
            .start_send(Message::binary(payload))
            .expect("websocket sink rejected an enqueued message");
        self.should_flush = true;
    }

    /// Polls the websocket sink to write out every enqueued message.
    ///
    /// On success the pending-flush flag reported by [`Self::should_flush`]
    /// is cleared.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if writing to the websocket fails.
    pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let result = Pin::new(&mut self.socket)
            .poll_flush(cx)
            .map_err(websockets_error_to_io);
        if matches!(result, Poll::Ready(Ok(()))) {
            // Everything queued by `enqueue_write_op` has now been written out.
            self.should_flush = false;
        }
        result
    }

    /// Writes out every enqueued message; async counterpart of [`Self::poll_flush`].
    ///
    /// # Errors
    ///
    /// Returns an I/O error if writing to the websocket fails.
    pub async fn flush(&mut self) -> io::Result<()> {
        future::poll_fn(|cx| self.poll_flush(cx)).await
    }

    /// Closes the websocket sink.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if closing the websocket fails.
    pub async fn shutdown(&mut self) -> io::Result<()> {
        future::poll_fn(|cx| Pin::new(&mut self.socket).poll_close(cx))
            .await
            .map_err(websockets_error_to_io)?;
        // `Sink::poll_close` flushes remaining output before closing, so no
        // enqueued ops are left pending.
        self.should_flush = false;
        Ok(())
    }
}
/// Errors returned by `WebsocketConnection::poll_read_next` / `read_next`.
#[derive(Debug, thiserror::Error)]
pub enum WebsocketReadError {
    /// A received binary payload could not be decoded into a server op.
    #[error("decoder")]
    Decoder(#[source] FrameDecoderError),
    /// The websocket stream yielded an error, converted to an [`io::Error`].
    #[error("io")]
    Io(#[source] io::Error),
    /// The websocket stream ended (it yielded no further messages).
    #[error("closed")]
    Closed,
}
fn websockets_error_to_io(err: tokio_websockets::Error) -> io::Error {
match err {
tokio_websockets::Error::Io(err) => err,
err => io::Error::other(err),
}
}