use std::cell::RefCell;
use std::rc::Rc;
use futures::StreamExt;
use futures_channel::{mpsc, oneshot};
use wasm_bindgen::prelude::*;
use web_sys::{BinaryType, CloseEvent, ErrorEvent, MessageEvent, WebSocket};
use super::WsError;
/// Sending half of a browser WebSocket connection.
///
/// Besides the socket handle it owns the `onopen` and `onerror` callbacks,
/// keeping them alive while they are registered on the socket. The `Drop`
/// impl for this type unregisters both handlers.
pub struct WsSender {
/// Handle to the underlying browser `WebSocket`.
ws: WebSocket,
/// Callback registered as the socket's `onopen` handler; stored only to keep it alive.
_onopen: Closure<dyn FnMut(JsValue)>,
/// Callback registered as the socket's `onerror` handler; stored only to keep it alive.
_onerror: Closure<dyn FnMut(ErrorEvent)>,
}
impl std::fmt::Debug for WsSender {
    /// Formats the sender as an opaque `WsSender { .. }`; no fields are listed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = formatter.debug_struct("WsSender");
        repr.finish_non_exhaustive()
    }
}
/// Receiving half of a browser WebSocket connection.
///
/// Incoming items are read from an unbounded channel. It also owns the
/// `onmessage` and `onclose` callbacks, keeping them alive while they are
/// registered on the socket. The `Drop` impl for this type unregisters both
/// handlers.
pub struct WsReceiver {
/// Handle to the underlying browser `WebSocket`.
ws: WebSocket,
/// Receiving end of the channel that carries text messages and errors.
rx: mpsc::UnboundedReceiver<Result<String, WsError>>,
/// Callback registered as the socket's `onmessage` handler; stored only to keep it alive.
_onmessage: Closure<dyn FnMut(MessageEvent)>,
/// Callback registered as the socket's `onclose` handler; stored only to keep it alive.
_onclose: Closure<dyn FnMut(CloseEvent)>,
}
impl std::fmt::Debug for WsReceiver {
    /// Formats the receiver as an opaque `WsReceiver { .. }`; no fields are listed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = formatter.debug_struct("WsReceiver");
        repr.finish_non_exhaustive()
    }
}
impl Drop for WsSender {
    /// Unregisters the `onopen` and `onerror` handlers from the socket.
    ///
    /// This runs before the struct's fields (including the stored closures)
    /// are dropped, so the socket no longer references them afterwards.
    fn drop(&mut self) {
        let socket = &self.ws;
        socket.set_onopen(None);
        socket.set_onerror(None);
    }
}
impl Drop for WsReceiver {
    /// Unregisters the `onmessage` and `onclose` handlers from the socket.
    ///
    /// This runs before the struct's fields (including the stored closures)
    /// are dropped, so the socket no longer references them afterwards.
    fn drop(&mut self) {
        let socket = &self.ws;
        socket.set_onmessage(None);
        socket.set_onclose(None);
    }
}
impl WsSender {
pub async fn send(&mut self, text: String) -> Result<(), WsError> {
self.ws
.send_with_str(&text)
.map_err(|e| WsError::Send(format!("{:?}", e)))
}
pub async fn close(&mut self) -> Result<(), WsError> {
self.ws
.close()
.map_err(|e| WsError::Send(format!("{:?}", e)))
}
}
impl WsReceiver {
pub async fn recv(&mut self) -> Option<Result<String, WsError>> {
self.rx.next().await
}
}
pub async fn connect(
url: &str,
headers: &[(&str, &str)],
) -> Result<(WsSender, WsReceiver), WsError> {
if !headers.is_empty() {
tracing::warn!(
"WebSocket headers are not supported on WASM (browser limitation). \
{} header(s) will be ignored.",
headers.len()
);
}
let ws = WebSocket::new(url).map_err(|e| WsError::Connection(format!("{:?}", e)))?;
ws.set_binary_type(BinaryType::Arraybuffer);
let (msg_tx, msg_rx) = mpsc::unbounded::<Result<String, WsError>>();
let (open_tx, open_rx) = oneshot::channel::<Result<(), WsError>>();
let open_tx: Rc<RefCell<Option<oneshot::Sender<Result<(), WsError>>>>> =
Rc::new(RefCell::new(Some(open_tx)));
let open_tx_open = Rc::clone(&open_tx);
let onopen = Closure::<dyn FnMut(JsValue)>::new(move |_: JsValue| {
if let Some(tx) = open_tx_open.borrow_mut().take() {
let _ = tx.send(Ok(()));
}
});
ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
let msg_tx_msg: mpsc::UnboundedSender<Result<String, WsError>> = msg_tx.clone();
let onmessage = Closure::<dyn FnMut(MessageEvent)>::new(move |e: MessageEvent| {
if let Some(text) = e.data().as_string() {
let _ = msg_tx_msg.unbounded_send(Ok(text));
}
});
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
let open_tx_err = Rc::clone(&open_tx);
let msg_tx_err: mpsc::UnboundedSender<Result<String, WsError>> = msg_tx.clone();
let onerror = Closure::<dyn FnMut(ErrorEvent)>::new(move |_e: ErrorEvent| {
let err = WsError::Connection("WebSocket error".to_string());
if let Some(tx) = open_tx_err.borrow_mut().take() {
let _ = tx.send(Err(err));
} else {
let _ = msg_tx_err.unbounded_send(Err(err));
}
});
ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
let onclose = Closure::<dyn FnMut(CloseEvent)>::new(move |_e: CloseEvent| {
msg_tx.close_channel();
});
ws.set_onclose(Some(onclose.as_ref().unchecked_ref()));
open_rx
.await
.map_err(|_| WsError::Connection("open channel dropped".to_string()))??;
let ws_clone = ws.clone();
Ok((
WsSender {
ws,
_onopen: onopen,
_onerror: onerror,
},
WsReceiver {
ws: ws_clone,
rx: msg_rx,
_onmessage: onmessage,
_onclose: onclose,
},
))
}