use bytes::Bytes;
use futures::{SinkExt, StreamExt, stream::SplitSink, stream::SplitStream};
use tokio_tungstenite::tungstenite::Message;
use crate::{Codec, Error, WebSocketStreamType};
/// Borrowed view over the split halves of a WebSocket connection together
/// with the codec used to encode/decode application messages.
///
/// This is the context type passed to [`ConnectionHandler::on_connected`].
pub struct HandshakeContext<'a, C>
where
    C: Codec,
{
    /// Codec used by `send`/`recv` to convert between values and frames.
    codec: &'a C,
    /// Outgoing half of the connection.
    sink: &'a mut SplitSink<WebSocketStreamType, Message>,
    /// Incoming half of the connection.
    stream: &'a mut SplitStream<WebSocketStreamType>,
}
impl<'a, C: Codec> HandshakeContext<'a, C> {
    /// Builds a context borrowing the connection's sink, stream and codec.
    pub(crate) fn new(
        sink: &'a mut SplitSink<WebSocketStreamType, Message>,
        stream: &'a mut SplitStream<WebSocketStreamType>,
        codec: &'a C,
    ) -> Self {
        Self {
            sink,
            stream,
            codec,
        }
    }

    /// Encodes `value` with the codec and sends the resulting frame.
    ///
    /// # Errors
    ///
    /// Returns the codec's encode error, or the sink's send error converted
    /// via `Error::from`.
    pub async fn send(&mut self, value: &C::Tx) -> Result<(), Error> {
        let message = self.codec.encode(value)?;
        self.sink.send(message).await.map_err(Error::from)
    }

    /// Receives the next data frame and decodes it with the codec.
    ///
    /// Ping/Pong frames are skipped (see [`Self::recv_raw`]).
    ///
    /// # Errors
    ///
    /// - `Error::WebsocketClosed` if the stream ended or the peer sent a Close frame.
    /// - `Error::WebsocketError` on a transport error.
    /// - Whatever the codec's `decode` returns for an undecodable frame.
    pub async fn recv(&mut self) -> Result<C::Rx, Error> {
        let frame = self.recv_data().await?;
        self.codec.decode(&frame)
    }

    /// Sends `text` as a single Text frame, bypassing the codec.
    ///
    /// # Errors
    ///
    /// Returns the sink's send error converted via `Error::from`.
    pub async fn send_text(&mut self, text: &str) -> Result<(), Error> {
        self.sink
            .send(Message::Text(text.into()))
            .await
            .map_err(Error::from)
    }

    /// Sends `bytes` as a single Binary frame, bypassing the codec.
    ///
    /// # Errors
    ///
    /// Returns the sink's send error converted via `Error::from`.
    pub async fn send_binary(&mut self, bytes: impl Into<Bytes>) -> Result<(), Error> {
        self.sink
            .send(Message::Binary(bytes.into()))
            .await
            .map_err(Error::from)
    }

    /// Receives the next data frame and requires it to be a Text frame.
    ///
    /// # Errors
    ///
    /// - `Error::WebsocketClosed` if the stream ended or the peer sent a Close
    ///   frame. This matches [`Self::recv`]; a Close is not reported as an
    ///   unexpected message type.
    /// - `Error::WebsocketError` on a transport error.
    /// - `Error::UnexpectedMessageType` carrying the frame if it was not Text.
    pub async fn recv_text(&mut self) -> Result<String, Error> {
        match self.recv_data().await? {
            Message::Text(text) => Ok(text.to_string()),
            other => Err(Error::UnexpectedMessageType(Box::new(other))),
        }
    }

    /// Receives the next non-Ping/Pong frame exactly as delivered by the stream.
    ///
    /// Close frames are returned to the caller unchanged.
    ///
    /// # Errors
    ///
    /// - `Error::WebsocketClosed` when the stream yields `None`.
    /// - `Error::WebsocketError` wrapping a transport error.
    pub async fn recv_raw(&mut self) -> Result<Message, Error> {
        loop {
            let Some(msg) = self.stream.next().await else {
                return Err(Error::WebsocketClosed);
            };
            match msg.map_err(Error::WebsocketError)? {
                // Control frames are not surfaced to handshake code.
                Message::Ping(_) | Message::Pong(_) => {}
                other => return Ok(other),
            }
        }
    }

    /// Like [`Self::recv_raw`], but treats a Close frame as a closed connection.
    ///
    /// Shared by `recv` and `recv_text` so both report a peer-initiated close
    /// the same way.
    async fn recv_data(&mut self) -> Result<Message, Error> {
        match self.recv_raw().await? {
            Message::Close(_) => Err(Error::WebsocketClosed),
            other => Ok(other),
        }
    }
}
pub trait ConnectionHandler<C: Codec>: Send + 'static {
fn on_connected(
&mut self,
ctx: &mut HandshakeContext<'_, C>,
) -> impl std::future::Future<Output = Result<(), Error>> + Send {
let _ = ctx;
async { Ok(()) }
}
fn on_disconnected(&mut self) -> impl std::future::Future<Output = ()> + Send {
async {}
}
}
/// A [`ConnectionHandler`] that keeps every default hook, i.e. does nothing
/// on connect or disconnect.
#[derive(Clone, Copy, Debug)]
pub struct NoopHandler;

impl<C> ConnectionHandler<C> for NoopHandler where C: Codec {}