use crate::{WebSocket, WebSocketConn, WebSocketHandler};
use async_tungstenite::tungstenite::{Message, protocol::CloseFrame};
use futures_lite::{Stream, ready};
use serde::{Serialize, de::DeserializeOwned};
use std::{
fmt::Debug,
future::Future,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};
/// A websocket handler whose messages are exchanged as JSON.
///
/// Implementors are wrapped in a [`JsonHandler`], which decodes inbound text
/// frames into [`Self::InboundMessage`] and serializes every
/// [`Self::OutboundMessage`] produced by [`Self::StreamType`] into a JSON text
/// message.
pub trait JsonWebSocketHandler: Send + Sync + 'static {
    /// The type each inbound message is deserialized into.
    type InboundMessage: DeserializeOwned + Send + 'static;

    /// The type of the values yielded by [`Self::StreamType`]; each one is
    /// serialized to JSON before being handed to the websocket layer.
    type OutboundMessage: Serialize + Send + 'static;

    /// The stream of outbound messages returned from
    /// [`connect`](Self::connect).
    type StreamType: Stream<Item = Self::OutboundMessage> + Send + Sync + 'static;

    /// Sets up a new connection and returns the stream of outbound messages
    /// for it.
    fn connect(&self, conn: &mut WebSocketConn) -> impl Future<Output = Self::StreamType> + Send;

    /// Receives the result of decoding one inbound message.
    ///
    /// `message` is an `Err` when the frame could not be read as text or when
    /// its text did not deserialize into [`Self::InboundMessage`].
    fn receive_message(
        &self,
        message: crate::Result<Self::InboundMessage>,
        conn: &mut WebSocketConn,
    ) -> impl Future<Output = ()> + Send;

    /// Called with the connection and an optional close frame on disconnect.
    ///
    /// The default implementation does nothing.
    // The default body deliberately ignores its arguments, so the lint is
    // silenced only here rather than on the whole trait.
    #[allow(unused_variables)]
    fn disconnect(
        &self,
        conn: &mut WebSocketConn,
        close_frame: Option<CloseFrame>,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}
/// Adapter that drives a [`JsonWebSocketHandler`] as a plain websocket
/// handler.
///
/// Dereferences to the wrapped handler, so its methods and fields remain
/// reachable through the wrapper.
pub struct JsonHandler<T> {
    /// The wrapped handler.
    pub(crate) handler: T,
}

impl<T> Deref for JsonHandler<T> {
    type Target = T;

    /// Borrows the wrapped handler.
    fn deref(&self) -> &T {
        let Self { handler } = self;
        handler
    }
}

impl<T> DerefMut for JsonHandler<T> {
    /// Mutably borrows the wrapped handler.
    fn deref_mut(&mut self) -> &mut T {
        let Self { handler } = self;
        handler
    }
}
impl<T: JsonWebSocketHandler> JsonHandler<T> {
pub(crate) fn new(handler: T) -> Self {
Self { handler }
}
}
/// Opaque `Debug` output: prints only the name `JsonWebSocketHandler`, so no
/// `T: Debug` bound is required and the wrapped handler is not shown.
impl<T> Debug for JsonHandler<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A field-less `debug_struct(..).finish()` writes just the name, in both
        // compact and alternate (`{:#?}`) modes; write it directly.
        f.write_str("JsonWebSocketHandler")
    }
}
pin_project_lite::pin_project! {
    /// Stream adapter that turns each item of `inner` into a websocket text
    /// [`Message`] holding the item's JSON encoding (see its `Stream` impl).
    #[derive(Debug)]
    pub struct SerializedStream<T> {
        // Structurally pinned so the wrapped stream can be polled via `Pin`.
        #[pin] inner: T
    }
}
impl<T> Stream for SerializedStream<T>
where
T: Stream,
T::Item: Serialize,
{
type Item = Message;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(
ready!(self.project().inner.poll_next(cx))
.and_then(|i| match serde_json::to_string(&i) {
Ok(j) => Some(j),
Err(e) => {
log::error!("serialization error: {e}");
None
}
})
.map(Message::text),
)
}
}
impl<T> WebSocketHandler for JsonHandler<T>
where
T: JsonWebSocketHandler,
{
type OutboundStream = SerializedStream<T::StreamType>;
async fn connect(
&self,
mut conn: WebSocketConn,
) -> Option<(WebSocketConn, Self::OutboundStream)> {
let stream = SerializedStream {
inner: self.handler.connect(&mut conn).await,
};
Some((conn, stream))
}
async fn inbound(&self, message: Message, conn: &mut WebSocketConn) {
self.handler
.receive_message(
message
.to_text()
.map_err(Into::into)
.and_then(|m| serde_json::from_str(m).map_err(Into::into)),
conn,
)
.await;
}
async fn disconnect(&self, conn: &mut WebSocketConn, close_frame: Option<CloseFrame>) {
self.handler.disconnect(conn, close_frame).await
}
}
impl<T> WebSocket<JsonHandler<T>>
where
    T: JsonWebSocketHandler,
{
    /// Builds a [`WebSocket`] driven by a [`JsonWebSocketHandler`]: inbound
    /// messages are decoded from JSON, and outbound ones are encoded as JSON.
    pub fn new_json(handler: T) -> Self {
        let json_handler = JsonHandler::new(handler);
        Self::new(json_handler)
    }
}
pub fn json_websocket<T>(json_websocket_handler: T) -> WebSocket<JsonHandler<T>>
where
T: JsonWebSocketHandler,
{
WebSocket::new_json(json_websocket_handler)
}