use std::cell::RefCell;
use std::io;
use std::rc::Rc;
use std::time::Duration;
use futures_util::StreamExt;
use roam_session::MessageTransport;
use roam_wire::Message;
use wasm_bindgen::JsCast;
use wasm_bindgen::closure::Closure;
use web_sys::{BinaryType, CloseEvent, ErrorEvent, MessageEvent, WebSocket};
/// Message transport over a browser [`WebSocket`] that exchanges postcard-encoded
/// [`Message`]s as binary frames.
///
/// Socket events are forwarded by JS callbacks into an unbounded channel which the
/// async receive side drains.
pub struct WsTransport {
    /// The underlying browser socket.
    ws: WebSocket,
    /// Receiving end of the channel fed by the socket's message/close/error handlers.
    rx: futures_channel::mpsc::UnboundedReceiver<WsEvent>,
    /// Owns the JS callbacks registered on `ws`; they must stay alive while registered.
    _closures: WsClosures,
    /// Raw bytes of the most recently received binary frame.
    last_decoded: Vec<u8>,
}
/// A socket event handed from a JS callback to the async receive side.
enum WsEvent {
    /// The socket's `close` event fired.
    Close,
    /// The socket's `error` event fired; carries a human-readable description.
    Error(String),
    /// A binary frame arrived; carries its raw bytes.
    Message(Vec<u8>),
}
/// Keeps the JS event-handler closures installed by `WsTransport::new` alive.
///
/// A `wasm_bindgen::closure::Closure` must outlive every JS reference to it, so these
/// are stored alongside the socket instead of being dropped after registration.
struct WsClosures {
    /// Handler registered via `set_onmessage`.
    _onmessage: Closure<dyn FnMut(MessageEvent)>,
    /// Handler registered via `set_onclose`.
    _onclose: Closure<dyn FnMut(CloseEvent)>,
    /// Handler registered via `set_onerror`.
    _onerror: Closure<dyn FnMut(ErrorEvent)>,
}
impl WsTransport {
pub fn new(ws: WebSocket) -> Self {
ws.set_binary_type(BinaryType::Arraybuffer);
let (tx, rx) = futures_channel::mpsc::unbounded();
let tx_msg = tx.clone();
let onmessage = Closure::wrap(Box::new(move |e: MessageEvent| {
if let Ok(abuf) = e.data().dyn_into::<js_sys::ArrayBuffer>() {
let array = js_sys::Uint8Array::new(&abuf);
let data = array.to_vec();
let _ = tx_msg.unbounded_send(WsEvent::Message(data));
}
}) as Box<dyn FnMut(MessageEvent)>);
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
let tx_close = tx.clone();
let onclose = Closure::wrap(Box::new(move |_: CloseEvent| {
let _ = tx_close.unbounded_send(WsEvent::Close);
}) as Box<dyn FnMut(CloseEvent)>);
ws.set_onclose(Some(onclose.as_ref().unchecked_ref()));
let tx_error = tx;
let onerror = Closure::wrap(Box::new(move |e: ErrorEvent| {
let msg = e.message();
let _ = tx_error.unbounded_send(WsEvent::Error(msg));
}) as Box<dyn FnMut(ErrorEvent)>);
ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
Self {
ws,
rx,
_closures: WsClosures {
_onmessage: onmessage,
_onclose: onclose,
_onerror: onerror,
},
last_decoded: Vec::new(),
}
}
pub async fn connect(url: &str) -> io::Result<Self> {
let ws = WebSocket::new(url)
.map_err(|e| io::Error::other(format!("failed to create WebSocket: {e:?}")))?;
let (open_tx, open_rx) = futures_channel::oneshot::channel::<Result<(), String>>();
let open_tx = Rc::new(RefCell::new(Some(open_tx)));
let open_tx_clone = open_tx.clone();
let onopen = Closure::once(Box::new(move || {
if let Some(tx) = open_tx_clone.borrow_mut().take() {
let _ = tx.send(Ok(()));
}
}) as Box<dyn FnOnce()>);
ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
let open_tx_clone = open_tx.clone();
let onerror_temp = Closure::once(Box::new(move |e: ErrorEvent| {
if let Some(tx) = open_tx_clone.borrow_mut().take() {
let _ = tx.send(Err(e.message()));
}
}) as Box<dyn FnOnce(ErrorEvent)>);
ws.set_onerror(Some(onerror_temp.as_ref().unchecked_ref()));
let result = open_rx
.await
.map_err(|_| io::Error::other("connection cancelled"))?;
ws.set_onopen(None);
ws.set_onerror(None);
drop(onopen);
drop(onerror_temp);
result.map_err(|e| io::Error::other(format!("connection failed: {e}")))?;
Ok(Self::new(ws))
}
pub fn websocket(&self) -> &WebSocket {
&self.ws
}
pub fn close(&self) -> io::Result<()> {
self.ws
.close()
.map_err(|e| io::Error::other(format!("close failed: {e:?}")))
}
}
impl MessageTransport for WsTransport {
async fn send(&mut self, msg: &Message) -> io::Result<()> {
let payload = facet_postcard::to_vec(msg)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
self.ws
.send_with_u8_array(&payload)
.map_err(|e| io::Error::other(format!("send failed: {e:?}")))?;
Ok(())
}
async fn recv_timeout(&mut self, timeout: Duration) -> io::Result<Option<Message>> {
roam_session::runtime::timeout(timeout, self.recv())
.await
.unwrap_or(Ok(None))
}
async fn recv(&mut self) -> io::Result<Option<Message>> {
loop {
match self.rx.next().await {
Some(WsEvent::Message(data)) => {
self.last_decoded = data.clone();
let msg: Message = facet_postcard::from_slice(&data).map_err(|e| {
web_sys::console::error_1(
&format!(
"postcard decode failed: {e}, bytes ({} total): {:?}",
data.len(),
&data[..data.len().min(100)]
)
.into(),
);
io::Error::new(io::ErrorKind::InvalidData, format!("postcard: {e}"))
})?;
return Ok(Some(msg));
}
Some(WsEvent::Close) => {
return Ok(None);
}
Some(WsEvent::Error(e)) => {
return Err(io::Error::other(format!("WebSocket error: {e}")));
}
None => {
return Ok(None);
}
}
}
}
fn last_decoded(&self) -> &[u8] {
&self.last_decoded
}
}
impl Drop for WsTransport {
    /// Detaches the handlers this transport installed, then asks the browser to close
    /// the socket.
    ///
    /// `drop` runs before the struct's fields (including the stored closures) are
    /// dropped, so the handlers are unregistered before their closures are freed. Any
    /// error from `close()` is discarded because `drop` cannot report it.
    fn drop(&mut self) {
        let socket = &self.ws;
        socket.set_onmessage(None);
        socket.set_onclose(None);
        socket.set_onerror(None);
        socket.close().ok();
    }
}