use std::fmt;
use std::marker::PhantomData;
use futures::{SinkExt, StreamExt};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio_tungstenite::tungstenite;
use tokio_tungstenite::WebSocketStream;
use typeway_core::session::*;
/// A WebSocket connection whose position in a session-typed protocol is tracked by the
/// type parameter `S`.
///
/// The state-specific methods in this module take `self` by value and return a
/// `TypedWebSocket` parameterised by the next state, so the compiler rejects calling a
/// step that the current state does not offer.
pub struct TypedWebSocket<S: SessionType> {
// Underlying upgraded WebSocket stream (hyper upgrade wrapped for tokio I/O).
inner: WebSocketStream<hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>>,
// Zero-sized marker for the current session state; no value of `S` is stored.
_state: PhantomData<S>,
}
impl<S: SessionType> TypedWebSocket<S> {
    /// Wraps an already-upgraded WebSocket stream and starts the session in state `S`.
    ///
    /// No frames are sent or read here. The caller chooses `S`, so it must describe the
    /// protocol the peer actually runs.
    pub fn new(inner: WebSocketStream<hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>>) -> Self {
        let _state = PhantomData;
        Self { _state, inner }
    }
}
impl<T: Serialize, Next: SessionType> TypedWebSocket<Send<T, Next>> {
    /// Sends `msg` to the peer and advances the session to `Next`.
    ///
    /// The message is encoded as JSON and transmitted as a single text frame.
    ///
    /// # Errors
    ///
    /// - [`WebSocketError::Serialize`] if `msg` cannot be encoded as JSON (nothing is sent).
    /// - [`WebSocketError::Transport`] if writing the frame to the stream fails.
    pub async fn send(mut self, msg: T) -> Result<TypedWebSocket<Next>, WebSocketError> {
        let payload = match serde_json::to_string(&msg) {
            Ok(encoded) => encoded,
            Err(err) => return Err(WebSocketError::Serialize(err)),
        };
        let frame = tungstenite::Message::Text(payload.into());
        if let Err(err) = self.inner.send(frame).await {
            return Err(WebSocketError::Transport(err));
        }
        Ok(TypedWebSocket {
            _state: PhantomData,
            inner: self.inner,
        })
    }
}
impl<T: DeserializeOwned, Next: SessionType> TypedWebSocket<Recv<T, Next>> {
    /// Receives one `T` from the peer and advances the session to `Next`.
    ///
    /// Frames are read until a text frame arrives, and that frame's payload is decoded from
    /// JSON into `T`. Frames that are neither text nor close (for example binary, ping or
    /// pong) are skipped.
    ///
    /// # Errors
    ///
    /// - [`WebSocketError::Deserialize`] if the text payload is not valid JSON for `T`.
    /// - [`WebSocketError::Closed`] if a close frame arrives or the stream ends first.
    /// - [`WebSocketError::Transport`] if the underlying stream yields an error.
    pub async fn recv(mut self) -> Result<(T, TypedWebSocket<Next>), WebSocketError> {
        let payload = loop {
            let frame = match self.inner.next().await {
                None => return Err(WebSocketError::Closed),
                Some(Err(err)) => return Err(WebSocketError::Transport(err)),
                Some(Ok(frame)) => frame,
            };
            if frame.is_close() {
                return Err(WebSocketError::Closed);
            }
            if frame.is_text() {
                break frame
                    .into_text()
                    .map_err(|_| WebSocketError::Protocol("expected text frame".into()))?;
            }
            // Any other frame kind is ignored; keep reading.
        };
        let value: T = serde_json::from_str(&payload).map_err(WebSocketError::Deserialize)?;
        let next = TypedWebSocket {
            _state: PhantomData,
            inner: self.inner,
        };
        Ok((value, next))
    }
}
impl TypedWebSocket<End> {
    /// Closes the connection once the protocol has reached `End`.
    ///
    /// Delegates to `WebSocketStream::close` with no close-frame payload (`None`).
    ///
    /// # Errors
    ///
    /// [`WebSocketError::Transport`] if the underlying close operation fails.
    pub async fn close(mut self) -> Result<(), WebSocketError> {
        match self.inner.close(None).await {
            Ok(()) => Ok(()),
            Err(err) => Err(WebSocketError::Transport(err)),
        }
    }
}
impl<L: SessionType, R: SessionType> TypedWebSocket<Offer<L, R>> {
pub async fn offer(
mut self,
) -> Result<Either<TypedWebSocket<L>, TypedWebSocket<R>>, WebSocketError> {
loop {
match self.inner.next().await {
Some(Ok(msg)) if msg.is_text() => {
let text = msg
.into_text()
.map_err(|_| WebSocketError::Protocol("expected text frame".into()))?;
if text.contains("\"branch\":\"L\"") || text.contains("\"branch\":\"left\"") {
return Ok(Either::Left(TypedWebSocket {
inner: self.inner,
_state: PhantomData,
}));
} else {
return Ok(Either::Right(TypedWebSocket {
inner: self.inner,
_state: PhantomData,
}));
}
}
Some(Ok(msg)) if msg.is_close() => return Err(WebSocketError::Closed),
Some(Ok(_)) => continue,
Some(Err(e)) => return Err(WebSocketError::Transport(e)),
None => return Err(WebSocketError::Closed),
}
}
}
}
impl<L: SessionType, R: SessionType> TypedWebSocket<Select<L, R>> {
    /// Chooses the left branch: sends `{"branch":"L"}` and continues as `L`.
    ///
    /// # Errors
    ///
    /// [`WebSocketError::Transport`] if the selection frame cannot be sent.
    pub async fn select_left(mut self) -> Result<TypedWebSocket<L>, WebSocketError> {
        self.announce_branch(r#"{"branch":"L"}"#).await?;
        Ok(TypedWebSocket {
            _state: PhantomData,
            inner: self.inner,
        })
    }

    /// Chooses the right branch: sends `{"branch":"R"}` and continues as `R`.
    ///
    /// # Errors
    ///
    /// [`WebSocketError::Transport`] if the selection frame cannot be sent.
    pub async fn select_right(mut self) -> Result<TypedWebSocket<R>, WebSocketError> {
        self.announce_branch(r#"{"branch":"R"}"#).await?;
        Ok(TypedWebSocket {
            _state: PhantomData,
            inner: self.inner,
        })
    }

    /// Sends the literal branch-selection payload `choice` as a single text frame.
    async fn announce_branch(&mut self, choice: &'static str) -> Result<(), WebSocketError> {
        self.inner
            .send(tungstenite::Message::Text(choice.into()))
            .await
            .map_err(WebSocketError::Transport)
    }
}
impl<B: SessionType> TypedWebSocket<Rec<B>> {
    /// Unrolls the recursive protocol `Rec<B>` and continues as its body `B`.
    ///
    /// This is a type-level transition only; no frames are sent or received.
    pub fn enter(self) -> TypedWebSocket<B> {
        let TypedWebSocket { inner, .. } = self;
        TypedWebSocket {
            inner,
            _state: PhantomData,
        }
    }
}
impl TypedWebSocket<Var> {
    /// Jumps from a recursion point `Var` back to the start of the loop `Rec<B>`.
    ///
    /// This is a type-level transition only; no frames are sent or received.
    ///
    /// NOTE(review): `Var` carries no type parameter here, so the compiler cannot check that
    /// `B` matches the body of the enclosing `Rec`. The caller must supply the correct body
    /// type.
    pub fn recurse<B: SessionType>(self) -> TypedWebSocket<Rec<B>> {
        let TypedWebSocket { inner, .. } = self;
        TypedWebSocket {
            inner,
            _state: PhantomData,
        }
    }
}
/// One of two alternatives.
///
/// `TypedWebSocket::offer` returns this to report which branch the peer selected:
/// `Left` carries the continuation for the first branch and `Right` the one for the second.
///
/// The derives only apply when both `L` and `R` implement the respective trait, so adding
/// them does not restrict existing uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Either<L, R> {
    /// The left-hand alternative.
    Left(L),
    /// The right-hand alternative.
    Right(R),
}
/// Errors produced while driving a `TypedWebSocket` through its session.
#[derive(Debug)]
pub enum WebSocketError {
/// The underlying WebSocket stream or sink reported an error.
Transport(tungstenite::Error),
/// An outgoing message could not be encoded as JSON.
Serialize(serde_json::Error),
/// An incoming text frame could not be decoded from JSON.
Deserialize(serde_json::Error),
/// The frames received did not match what the protocol expected.
Protocol(String),
/// The peer sent a close frame, or the stream ended, before the expected message arrived.
Closed,
}
impl fmt::Display for WebSocketError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WebSocketError::Transport(e) => write!(f, "transport error: {e}"),
WebSocketError::Serialize(e) => write!(f, "serialization error: {e}"),
WebSocketError::Deserialize(e) => write!(f, "deserialization error: {e}"),
WebSocketError::Protocol(msg) => write!(f, "protocol error: {msg}"),
WebSocketError::Closed => write!(f, "connection closed"),
}
}
}
impl std::error::Error for WebSocketError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
WebSocketError::Transport(e) => Some(e),
WebSocketError::Serialize(e) => Some(e),
WebSocketError::Deserialize(e) => Some(e),
_ => None,
}
}
}