#![allow(clippy::unwrap_used)]
use std::{cell::RefCell, rc::Rc};
use async_trait::async_trait;
use derive_more::{From, Into};
use futures::{channel::mpsc, stream::LocalBoxStream, StreamExt};
use medea_client_api_proto::{ClientMsg, ServerMsg};
use medea_reactive::ObservableCell;
use tracerr::Traced;
use web_sys::{CloseEvent, Event, MessageEvent, WebSocket as SysWebSocket};
use crate::{
platform::{
transport::{RpcTransport, TransportError, TransportState},
wasm::utils::EventListener,
},
rpc::{websocket::ClientDisconnect, ApiUrl, CloseMsg},
};
#[derive(Clone, From, Into)]
struct ServerMessage(ServerMsg);
impl TryFrom<&MessageEvent> for ServerMessage {
type Error = TransportError;
fn try_from(msg: &MessageEvent) -> Result<Self, Self::Error> {
use TransportError::{MessageNotString, ParseServerMessage};
let payload = msg.data().as_string().ok_or(MessageNotString)?;
serde_json::from_str::<ServerMsg>(&payload)
.map_err(|e| ParseServerMessage(e.into()))
.map(Self::from)
}
}
/// [`Result`] whose error is a [`Traced`] [`TransportError`].
type TransportResult<T> = Result<T, Traced<TransportError>>;
/// State shared between a [`WebSocketRpcTransport`] and the event listeners
/// it registers on the [`SysWebSocket`].
///
/// NOTE: fields are dropped in declaration order (after `Drop::drop` runs),
/// so don't reorder them casually.
#[derive(Debug)]
struct InnerSocket {
    /// Underlying [`SysWebSocket`]. It is `None` until `connect()` creates it.
    socket: RefCell<Option<SysWebSocket>>,
    /// Current [`TransportState`]. Changes can be subscribed to.
    socket_state: ObservableCell<TransportState>,
    /// Listener of the socket's `open` event, registered by `connect()`.
    on_open_listener: Option<EventListener<SysWebSocket, Event>>,
    /// Listener of the socket's `message` event, registered once the socket
    /// has been opened.
    on_message_listener: Option<EventListener<SysWebSocket, MessageEvent>>,
    /// Listener of the socket's `close` event.
    on_close_listener: Option<EventListener<SysWebSocket, CloseEvent>>,
    /// Senders feeding the streams returned by `on_message()`. Senders whose
    /// receiver is gone are pruned when the next message is forwarded.
    on_message_subs: Vec<mpsc::UnboundedSender<ServerMsg>>,
    /// Reason (serialized to JSON) passed when the socket is closed on drop.
    close_reason: ClientDisconnect,
}
impl InnerSocket {
    /// Creates a new [`InnerSocket`] in the [`TransportState::Connecting`]
    /// state. It starts with no [`SysWebSocket`], event listeners or message
    /// subscribers.
    ///
    /// Until overridden, the close reason is
    /// [`ClientDisconnect::RpcTransportUnexpectedlyDropped`].
    const fn new() -> Self {
        Self {
            // The socket is created later, by `RpcTransport::connect()`.
            socket: RefCell::new(None),
            socket_state: ObservableCell::new(TransportState::Connecting),
            on_open_listener: None,
            on_message_listener: None,
            on_close_listener: None,
            on_message_subs: Vec::new(),
            close_reason: ClientDisconnect::RpcTransportUnexpectedlyDropped,
        }
    }
}
impl Drop for InnerSocket {
fn drop(&mut self) {
if self.socket_state.borrow().can_close() {
let rsn =
serde_json::to_string(&self.close_reason).unwrap_or_else(|e| {
panic!("Could not serialize close message: {e}")
});
if let Some(socket) = self.socket.borrow().as_ref() {
if let Err(e) = socket.close_with_code_and_reason(1000, &rsn) {
log::error!("Failed to normally close socket: {e:?}");
}
}
}
}
}
#[derive(Debug)]
pub struct WebSocketRpcTransport(Rc<RefCell<InnerSocket>>);
impl WebSocketRpcTransport {
#[must_use]
pub fn new() -> Self {
Self(Rc::new(RefCell::new(InnerSocket::new())))
}
fn set_on_close_listener(&self, socket: SysWebSocket) {
let this = Rc::clone(&self.0);
let on_close = EventListener::new_once(
Rc::new(socket),
"close",
move |msg: CloseEvent| {
this.borrow().socket_state.set(TransportState::Closed(
CloseMsg::from((msg.code(), msg.reason())),
));
},
)
.unwrap();
self.0.borrow_mut().on_close_listener = Some(on_close);
}
fn set_on_message_listener(&self, socket: SysWebSocket) {
let this = Rc::clone(&self.0);
let on_message =
EventListener::new_mut(Rc::new(socket), "message", move |msg| {
let msg =
match ServerMessage::try_from(&msg).map(ServerMsg::from) {
Ok(parsed) => parsed,
Err(e) => {
log::error!("{}", tracerr::new!(e));
return;
}
};
let mut this_mut = this.borrow_mut();
this_mut.on_message_subs.retain(|on_message| {
on_message.unbounded_send(msg.clone()).is_ok()
});
})
.unwrap();
self.0.borrow_mut().on_message_listener = Some(on_message);
}
}
impl Default for WebSocketRpcTransport {
fn default() -> Self {
Self::new()
}
}
#[async_trait(?Send)]
impl RpcTransport for WebSocketRpcTransport {
async fn connect(&self, url: ApiUrl) -> TransportResult<()> {
let socket = SysWebSocket::new(url.as_ref())
.map_err(Into::into)
.map_err(TransportError::CreateSocket)
.map_err(tracerr::wrap!())?;
*self.0.borrow_mut().socket.borrow_mut() = Some(socket.clone());
{
{
let inner = Rc::clone(&self.0);
self.0.borrow_mut().on_close_listener = Some(
EventListener::new_once(
Rc::clone(&Rc::new(socket.clone())),
"close",
move |msg: CloseEvent| {
inner.borrow().socket_state.set(
TransportState::Closed(CloseMsg::from((
msg.code(),
msg.reason(),
))),
);
},
)
.unwrap(),
);
}
{
let inner = Rc::clone(&self.0);
self.0.borrow_mut().on_open_listener = Some(
EventListener::new_once(
Rc::clone(&Rc::new(socket.clone())),
"open",
move |_| {
inner
.borrow()
.socket_state
.set(TransportState::Open);
},
)
.unwrap(),
);
}
}
let state_updates_rx = self.0.borrow().socket_state.subscribe();
let state = state_updates_rx.skip(1).next().await;
if state == Some(TransportState::Open) {
self.set_on_close_listener(socket.clone());
self.set_on_message_listener(socket);
Ok(())
} else {
Err(tracerr::new!(TransportError::InitSocket))
}
}
fn on_message(&self) -> LocalBoxStream<'static, ServerMsg> {
let (tx, rx) = mpsc::unbounded();
self.0.borrow_mut().on_message_subs.push(tx);
Box::pin(rx)
}
fn set_close_reason(&self, reason: ClientDisconnect) {
self.0.borrow_mut().close_reason = reason;
}
fn send(&self, msg: &ClientMsg) -> TransportResult<()> {
let inner = self.0.borrow();
let message = serde_json::to_string(msg)
.map_err(|e| TransportError::SerializeClientMessage(e.into()))
.map_err(tracerr::wrap!())?;
let state = &*inner.socket_state.borrow();
match state {
TransportState::Open => inner.socket.borrow().as_ref().map_or_else(
|| Err(tracerr::new!(TransportError::ClosedSocket)),
|socket| {
socket
.send_with_str(&message)
.map_err(Into::into)
.map_err(TransportError::SendMessage)
.map_err(tracerr::wrap!())
},
),
TransportState::Connecting
| TransportState::Closing
| TransportState::Closed(_) => {
Err(tracerr::new!(TransportError::ClosedSocket))
}
}
}
fn on_state_change(&self) -> LocalBoxStream<'static, TransportState> {
self.0.borrow().socket_state.subscribe()
}
}
impl Drop for WebSocketRpcTransport {
    /// Drops every event listener stored in the shared [`InnerSocket`].
    ///
    /// The listener closures capture clones of the `Rc` pointing to that same
    /// [`InnerSocket`]. Dropping them here breaks the reference cycle, so the
    /// [`InnerSocket`] itself can be dropped afterwards. Dropping an
    /// `EventListener` presumably also unregisters it from the socket.
    fn drop(&mut self) {
        let mut inner = self.0.borrow_mut();
        // Assigning `None` drops the previous listener in place, in the same
        // order as before: `open`, `message`, `close`.
        inner.on_open_listener = None;
        inner.on_message_listener = None;
        inner.on_close_listener = None;
    }
}