use futures::{SinkExt, StreamExt, stream::SplitSink, stream::SplitStream};
use serde::{Deserialize, Serialize};
use tokio_tungstenite::tungstenite::Message;
use crate::{Error, WebSocketStreamType};
/// Borrowed access to both halves of a split WebSocket connection.
///
/// This is the context type passed to [`ConnectionHandler::on_connected`], letting a handler
/// exchange messages with the peer. Both halves are held as exclusive (`&mut`) borrows, so
/// nothing else can use the connection while a `HandshakeContext` is alive.
pub struct HandshakeContext<'conn> {
    /// Outgoing half of the connection; messages are written here.
    sink: &'conn mut SplitSink<WebSocketStreamType, Message>,
    /// Incoming half of the connection; messages are read from here.
    stream: &'conn mut SplitStream<WebSocketStreamType>,
}
impl<'a> HandshakeContext<'a> {
pub(crate) fn new(
sink: &'a mut SplitSink<WebSocketStreamType, Message>,
stream: &'a mut SplitStream<WebSocketStreamType>,
) -> Self {
Self { sink, stream }
}
pub async fn send_text(&mut self, text: &str) -> Result<(), Error> {
self.sink
.send(Message::Text(text.into()))
.await
.map_err(Error::from)
}
pub async fn recv_text(&mut self) -> Result<String, Error> {
loop {
let Some(msg) = self.stream.next().await else {
return Err(Error::WebsocketClosed);
};
match msg.map_err(Error::WebsocketError)? {
Message::Text(text) => return Ok(text.to_string()),
Message::Ping(_) | Message::Pong(_) => {}
other => return Err(Error::UnexpectedMessageType(Box::new(other))),
}
}
}
pub async fn send_json<T: Serialize>(&mut self, msg: &T) -> Result<(), Error> {
let text = serde_json::to_string(msg)?;
self.send_text(&text).await
}
pub async fn recv_json<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T, Error> {
let text = self.recv_text().await?;
serde_json::from_str(&text).map_err(Error::from)
}
}
pub trait ConnectionHandler: Send + 'static {
fn on_connected(
&mut self,
ctx: &mut HandshakeContext<'_>,
) -> impl std::future::Future<Output = Result<(), Error>> + Send {
let _ = ctx;
async { Ok(()) }
}
fn on_disconnected(&mut self) -> impl std::future::Future<Output = ()> + Send {
async {}
}
}
/// A [`ConnectionHandler`] that uses the trait's default hooks unchanged.
///
/// It sends and receives nothing in `on_connected` (always resolving to `Ok(())`) and does
/// nothing in `on_disconnected`. It is a zero-sized unit struct. Deriving `Default` lets
/// generic code that needs `H: Default` construct one.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoopHandler;

impl ConnectionHandler for NoopHandler {}