use crate::{WebSocket, WebSocketConn, WebSocketHandler};
use async_tungstenite::tungstenite::{protocol::CloseFrame, Message};
use futures_lite::{ready, Stream, StreamExt};
use serde::{de::DeserializeOwned, Serialize};
use std::{
fmt::Debug,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};
use trillium::async_trait;
/// A higher-level websocket handler that exchanges JSON-encoded messages.
///
/// Implementations are wrapped in a [`JsonHandler`] (for example via
/// [`json_websocket`]) to obtain a [`WebSocketHandler`]: inbound messages are
/// deserialized with `serde_json` before reaching [`Self::receive_message`], and
/// items of the stream returned from [`Self::connect`] are serialized to JSON
/// text messages.
// `unused_variables` is allowed because the default `disconnect` body below
// intentionally ignores its arguments.
#[allow(unused_variables)]
#[async_trait]
pub trait JsonWebSocketHandler: Send + Sync + 'static {
    /// Type that inbound websocket message text is deserialized into.
    type InboundMessage: DeserializeOwned + Send + 'static;
    /// Type of the items yielded by [`Self::StreamType`]; each one is serialized
    /// to a JSON text message.
    type OutboundMessage: Serialize + Send + 'static;
    /// Stream of outbound messages returned from [`Self::connect`].
    type StreamType: Stream<Item = Self::OutboundMessage> + Unpin + Send + Sync + 'static;

    /// Invoked from `JsonHandler`'s `WebSocketHandler::connect`.
    ///
    /// Returns the stream whose items are serialized to JSON and used as the
    /// connection's outbound stream.
    async fn connect(&self, conn: &mut WebSocketConn) -> Self::StreamType;

    /// Called with each inbound message that could be read as text and
    /// deserialized into [`Self::InboundMessage`]. Messages that fail either
    /// step are dropped before reaching this method.
    async fn receive_message(&self, message: Self::InboundMessage, conn: &mut WebSocketConn);

    /// Forwarded from `JsonHandler`'s `WebSocketHandler::disconnect`, along with
    /// the close frame it received (if any).
    ///
    /// The default implementation does nothing.
    async fn disconnect(&self, conn: &mut WebSocketConn, close_frame: Option<CloseFrame<'static>>) {
    }
}
/// Adapter that wraps a [`JsonWebSocketHandler`] so it can be used as a
/// [`WebSocketHandler`], converting messages to and from JSON.
pub struct JsonHandler<T> {
    /// The wrapped user handler; also reachable through `Deref`/`DerefMut`.
    pub(crate) handler: T,
}
impl<T> Deref for JsonHandler<T> {
    type Target = T;

    /// Borrows the wrapped handler immutably.
    fn deref(&self) -> &T {
        let Self { handler } = self;
        handler
    }
}

impl<T> DerefMut for JsonHandler<T> {
    /// Borrows the wrapped handler mutably.
    fn deref_mut(&mut self) -> &mut T {
        let Self { handler } = self;
        handler
    }
}
impl<T: JsonWebSocketHandler> JsonHandler<T> {
pub(crate) fn new(handler: T) -> Self {
Self { handler }
}
}
impl<T> Debug for JsonHandler<T> {
    /// Renders as the bare name `JsonWebSocketHandler`. The wrapped handler is
    /// not required to implement `Debug`, so its contents are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output as a field-less `debug_struct(..).finish()`.
        f.write_str("JsonWebSocketHandler")
    }
}
/// A [`Stream`] adapter that serializes each item of the wrapped stream to a
/// JSON string and yields it as a [`Message::Text`].
#[derive(Debug)]
pub struct SerializedStream<T>(T);

impl<T> Stream for SerializedStream<T>
where
    T: Stream + Unpin,
    T::Item: Serialize,
{
    type Item = Message;

    /// Polls the inner stream and serializes its next item to JSON text.
    ///
    /// Items that fail to serialize are skipped rather than ending the stream.
    /// `Ready(None)` is returned only once the inner stream itself is exhausted.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            match ready!(self.0.poll_next(cx)) {
                Some(item) => {
                    // Previously a serialization error was mapped to `None`,
                    // which terminated the whole outbound stream. Skip the bad
                    // item and poll for the next one instead.
                    if let Ok(json) = serde_json::to_string(&item) {
                        return Poll::Ready(Some(Message::Text(json)));
                    }
                }
                None => return Poll::Ready(None),
            }
        }
    }
}
#[async_trait]
impl<T> WebSocketHandler for JsonHandler<T>
where
    T: JsonWebSocketHandler,
{
    type OutboundStream = SerializedStream<T::StreamType>;

    /// Delegates to the wrapped handler's `connect`, then wraps the returned
    /// stream so every outbound item is serialized to JSON text.
    async fn connect(
        &self,
        mut conn: WebSocketConn,
    ) -> Option<(WebSocketConn, Self::OutboundStream)> {
        let outbound = self.handler.connect(&mut conn).await;
        Some((conn, SerializedStream(outbound)))
    }

    /// Parses an inbound message as JSON and hands it to the wrapped handler.
    ///
    /// A message is silently ignored if its payload cannot be read as text
    /// (`Message::to_text` fails) or does not deserialize into
    /// `T::InboundMessage`.
    async fn inbound(&self, message: Message, conn: &mut WebSocketConn) {
        let parsed = match message.to_text() {
            Ok(text) => serde_json::from_str::<T::InboundMessage>(text).ok(),
            Err(_) => None,
        };

        if let Some(inbound) = parsed {
            self.handler.receive_message(inbound, conn).await;
        }
    }

    /// Forwards the disconnect notification and close frame to the wrapped handler.
    async fn disconnect(&self, conn: &mut WebSocketConn, close_frame: Option<CloseFrame<'static>>) {
        self.handler.disconnect(conn, close_frame).await;
    }
}
impl<T> WebSocket<JsonHandler<T>>
where
    T: JsonWebSocketHandler,
{
    /// Builds a [`WebSocket`] from a [`JsonWebSocketHandler`], wrapping it in a
    /// [`JsonHandler`] so messages are converted to and from JSON.
    pub fn new_json(handler: T) -> Self {
        let json_handler = JsonHandler::new(handler);
        Self::new(json_handler)
    }
}
pub fn json_websocket<T>(json_websocket_handler: T) -> WebSocket<JsonHandler<T>>
where
T: JsonWebSocketHandler,
{
WebSocket::new_json(json_websocket_handler)
}