use std::marker::PhantomData;
use futures_util::stream::StreamExt;
use futures_util::SinkExt;
use gloo_net::websocket::futures::WebSocket;
use gloo_net::websocket::Message;
use crate::codec::{WsCodec, WsMessage};
use crate::connection::{RecvError, SendError};
use crate::WsEndpoint;
/// Error returned by [`connect`] and [`connect_to`] when the WebSocket cannot be opened.
///
/// Wraps the underlying `gloo_net` WebSocket error; the `Display` output renders
/// that inner error with its `Debug` representation.
#[derive(Debug)]
pub struct ConnectError(gloo_net::websocket::WebSocketError);

impl std::fmt::Display for ConnectError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "WebSocket connect error: {:?}", self.0)
    }
}

// Implementing the standard error trait lets callers propagate this with `?`
// into `Box<dyn std::error::Error>` and other error-trait based handling.
impl std::error::Error for ConnectError {}
/// Opens a typed connection to `E::PATH` on the host serving the current page.
///
/// The scheme is `wss:` when the page was loaded over `https:`, otherwise `ws:`.
///
/// # Errors
/// Returns [`ConnectError`] if the socket cannot be opened.
///
/// # Panics
/// Panics if the browser window, or its location's protocol/host, is unavailable
/// (see `derive_ws_url`).
pub fn connect<E: WsEndpoint>() -> Result<Connection<E>, ConnectError> {
    connect_to::<E>(derive_ws_url(E::PATH).as_str())
}
/// Opens a typed connection to an explicit WebSocket `url`.
///
/// On success the socket is split into its send and receive halves, which
/// back the returned [`Connection`].
///
/// # Errors
/// Returns [`ConnectError`] if `WebSocket::open` fails.
pub fn connect_to<E: WsEndpoint>(url: &str) -> Result<Connection<E>, ConnectError> {
    match WebSocket::open(url) {
        Ok(socket) => {
            let (sink, stream) = socket.split();
            Ok(Connection {
                sink,
                stream,
                _types: PhantomData,
            })
        }
        // NOTE(review): the open failure is carried in the `MessageSendError`
        // variant, so `ConnectError`'s output names a send error even though this
        // is a connect failure — consider a dedicated error variant.
        Err(js_err) => Err(ConnectError(
            gloo_net::websocket::WebSocketError::MessageSendError(js_err),
        )),
    }
}
/// A typed WebSocket connection for endpoint `E`.
///
/// Sends `E::ClientMsg` values and receives `E::ServerMsg` values. Use
/// [`Connection::split`] to obtain independent sending and receiving halves.
pub struct Connection<E: WsEndpoint> {
    /// Outgoing half of the underlying socket.
    sink: futures_util::stream::SplitSink<WebSocket, Message>,
    /// Incoming half of the underlying socket.
    stream: futures_util::stream::SplitStream<WebSocket>,
    /// Ties the connection to its endpoint's message types without storing a value.
    _types: PhantomData<E>,
}

impl<E: WsEndpoint> Connection<E> {
    /// Encodes `msg` and sends it as a single WebSocket frame.
    ///
    /// # Errors
    /// Encoding failures are propagated with `?` (converted into [`SendError`]);
    /// any failure from the underlying sink is reported as [`SendError::Closed`].
    pub async fn send(&mut self, msg: E::ClientMsg) -> Result<(), SendError> {
        let gloo_msg = match msg.encode()? {
            WsMessage::Text(t) => Message::Text(t),
            WsMessage::Binary(b) => Message::Bytes(b),
        };
        self.sink
            .send(gloo_msg)
            .await
            .map_err(|_| SendError::Closed)
    }

    /// Waits for the next frame from the server and decodes it.
    ///
    /// Returns `None` once the underlying stream ends, `Some(Err(RecvError::Closed))`
    /// if the stream yields an error, and `Some(Err(RecvError::Decode(_)))` if the
    /// frame cannot be decoded as `E::ServerMsg`.
    pub async fn recv(&mut self) -> Option<Result<E::ServerMsg, RecvError>> {
        // Every outcome of a single poll is final, so no retry loop is needed;
        // `?` forwards the end-of-stream `None` directly.
        let msg = match self.stream.next().await? {
            Ok(msg) => msg,
            Err(_) => return Some(Err(RecvError::Closed)),
        };
        let ws_msg = match msg {
            Message::Text(t) => WsMessage::Text(t),
            Message::Bytes(b) => WsMessage::Binary(b),
        };
        Some(E::ServerMsg::decode(ws_msg).map_err(RecvError::Decode))
    }

    /// Consumes the connection and returns its sending and receiving halves,
    /// which can then be used independently (e.g. from separate tasks).
    pub fn split(self) -> (Sender<E>, Receiver<E>) {
        (
            Sender {
                sink: self.sink,
                _msg: PhantomData,
            },
            Receiver {
                stream: self.stream,
                _msg: PhantomData,
            },
        )
    }
}
/// Sending half of a [`Connection`], produced by [`Connection::split`].
pub struct Sender<E: WsEndpoint> {
    /// Outgoing half of the underlying socket.
    sink: futures_util::stream::SplitSink<WebSocket, Message>,
    /// Ties the sender to its endpoint's message types without storing a value.
    _msg: PhantomData<E>,
}

impl<E: WsEndpoint> Sender<E> {
    /// Encodes `msg` and writes it to the socket as one frame.
    ///
    /// # Errors
    /// Encoding failures are propagated with `?` (converted into [`SendError`]);
    /// a failed write to the sink yields [`SendError::Closed`].
    pub async fn send(&mut self, msg: E::ClientMsg) -> Result<(), SendError> {
        let frame = match msg.encode()? {
            WsMessage::Binary(bytes) => Message::Bytes(bytes),
            WsMessage::Text(text) => Message::Text(text),
        };
        match self.sink.send(frame).await {
            Ok(()) => Ok(()),
            Err(_) => Err(SendError::Closed),
        }
    }
}
/// Receiving half of a [`Connection`], produced by [`Connection::split`].
pub struct Receiver<E: WsEndpoint> {
    /// Incoming half of the underlying socket.
    stream: futures_util::stream::SplitStream<WebSocket>,
    /// Ties the receiver to its endpoint's message types without storing a value.
    _msg: PhantomData<E>,
}

impl<E: WsEndpoint> Receiver<E> {
    /// Waits for the next frame from the server and decodes it.
    ///
    /// Returns `None` once the underlying stream ends, `Some(Err(RecvError::Closed))`
    /// if the stream yields an error, and `Some(Err(RecvError::Decode(_)))` if the
    /// frame cannot be decoded as `E::ServerMsg`.
    pub async fn recv(&mut self) -> Option<Result<E::ServerMsg, RecvError>> {
        // Every outcome of a single poll is final, so no retry loop is needed;
        // `?` forwards the end-of-stream `None` directly.
        let msg = match self.stream.next().await? {
            Ok(msg) => msg,
            Err(_) => return Some(Err(RecvError::Closed)),
        };
        let ws_msg = match msg {
            Message::Text(t) => WsMessage::Text(t),
            Message::Bytes(b) => WsMessage::Binary(b),
        };
        Some(E::ServerMsg::decode(ws_msg).map_err(RecvError::Decode))
    }
}
/// Builds a WebSocket URL for `path` on the host serving the current page.
///
/// The scheme is `wss:` when the page protocol is `https:`, otherwise `ws:`.
/// A `/` is inserted between host and path when `path` does not already start
/// with one; paths that do are used unchanged.
///
/// # Panics
/// Panics if `web_sys::window()` returns `None`, or if the location's protocol
/// or host cannot be read.
fn derive_ws_url(path: &str) -> String {
    let window = web_sys::window().expect("no window object");
    let location = window.location();
    let protocol = location.protocol().expect("no protocol");
    let host = location.host().expect("no host");
    let ws_protocol = if protocol == "https:" { "wss:" } else { "ws:" };
    // Without a separator a relative path would be glued onto the host
    // (e.g. `ws://example.comws`), producing a URL for the wrong host.
    let sep = if path.starts_with('/') { "" } else { "/" };
    format!("{ws_protocol}//{host}{sep}{path}")
}