use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::error::Error;
use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_tungstenite::WebSocketStream;
pub struct TypedWebSocketStream<S, INPUT, OUTPUT>
where
INPUT: Serialize,
OUTPUT: for<'de> Deserialize<'de>,
{
stream: WebSocketStream<S>,
_marker_in: std::marker::PhantomData<INPUT>,
_marker_out: std::marker::PhantomData<OUTPUT>,
}
impl<S, INPUT, OUTPUT> TypedWebSocketStream<S, INPUT, OUTPUT>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    INPUT: Serialize,
    OUTPUT: for<'de> Deserialize<'de>,
{
    /// Wraps an already-established WebSocket connection.
    pub fn new(stream: WebSocketStream<S>) -> Self {
        Self {
            stream,
            _marker_in: std::marker::PhantomData,
            _marker_out: std::marker::PhantomData,
        }
    }

    /// Serializes `message` to JSON and sends it as a single text frame.
    ///
    /// # Errors
    ///
    /// Returns an error if JSON serialization fails or if the underlying
    /// WebSocket sink fails to send the frame.
    pub async fn send(&mut self, message: INPUT) -> Result<(), Box<dyn Error>> {
        let json = serde_json::to_string(&message)?;
        self.stream.send(Message::Text(json)).await?;
        Ok(())
    }

    /// Waits for the next text frame and deserializes its JSON payload into `OUTPUT`.
    ///
    /// Ping and pong frames carry no application payload and are skipped, so a
    /// keep-alive exchange does not cause the receive to fail.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the underlying stream yields a transport/protocol error (propagated as-is),
    /// - the stream ends before a text frame arrives,
    /// - a non-text data frame (e.g. binary or close) is received, or
    /// - the text payload is not valid JSON for `OUTPUT`.
    pub async fn receive(&mut self) -> Result<OUTPUT, Box<dyn Error>> {
        loop {
            // Surface real transport errors instead of collapsing them into a
            // generic message, and report end-of-stream distinctly.
            let frame = match self.stream.next().await {
                Some(result) => result?,
                None => {
                    return Err("WebSocket stream closed before a message was received".into())
                }
            };
            match frame {
                Message::Text(json) => return Ok(serde_json::from_str(&json)?),
                // Per tungstenite's documentation, replies to incoming pings are
                // queued automatically; these frames need no handling here, so
                // keep waiting for the next data frame.
                Message::Ping(_) | Message::Pong(_) => {}
                _ => return Err("Failed to receive valid text message".into()),
            }
        }
    }

    /// Sends a close frame (with no close code or reason) to the peer.
    ///
    /// # Errors
    ///
    /// Returns an error if the close frame cannot be sent on the underlying sink.
    pub async fn close(&mut self) -> Result<(), Box<dyn Error>> {
        self.stream.send(Message::Close(None)).await?;
        Ok(())
    }
}