use std::{
cell::{Cell, RefCell},
rc::Rc,
};
use async_trait::async_trait;
use derivative::Derivative;
use derive_more::with_trait::{Display, From};
use futures::{
StreamExt as _,
channel::mpsc,
future::{self, LocalBoxFuture},
stream::LocalBoxStream,
};
use medea_client_api_proto::{Command, Event, MemberId, RoomId};
use medea_reactive::ObservableCell;
use tracerr::Traced;
use crate::{
platform,
rpc::{
ClientDisconnect, CloseReason, ConnectionInfo, RpcClientError,
WebSocketRpcClient, websocket::RpcEventHandler,
},
utils::Caused,
};
/// Errors which may be returned by an [`RpcSession`].
#[derive(Caused, Clone, Debug, Display, From)]
#[cause(error = platform::Error)]
pub enum SessionError {
    /// The session has already reached its terminal state with the carried
    /// [`CloseReason`].
    #[display("RPC Session finished with {_0:?} close reason")]
    SessionFinished(CloseReason),

    /// A connection was requested while no [`ConnectionInfo`] was available.
    #[display("RPC Session doesn't have any credentials to authorize with")]
    NoCredentials,

    /// The session dropped its credentials while a connection attempt was
    /// being awaited.
    #[display("Failed to authorize RPC session")]
    AuthorizationFailed,

    /// Error of the underlying RPC client.
    #[display("RpcClientError: {_0}")]
    RpcClient(#[cause] RpcClientError),

    /// The stream of session state updates ended before the session opened.
    #[display("RPC Session was unexpectedly dropped")]
    SessionUnexpectedlyDropped,

    /// The connection was lost (or could not be established) while the
    /// session was being opened.
    #[display("Connection with a server was lost: {_0}")]
    ConnectionLost(ConnectionLostReason),

    /// A different [`ConnectionInfo`] was supplied while a connection attempt
    /// was being awaited.
    #[display("New connection info was provided")]
    NewConnectionInfo,
}
/// Why the connection of a session was lost.
#[derive(Clone, Debug, Display)]
pub enum ConnectionLostReason {
    /// Establishing the transport connection failed.
    ConnectError(Traced<RpcClientError>),

    /// An established connection was reported as lost by the RPC client.
    Lost(super::ConnectionLostReason),
}
impl Caused for ConnectionLostReason {
    type Error = platform::Error;

    /// Returns the underlying [`platform::Error`], if any.
    ///
    /// Only a failed connection attempt may carry one, delegated to its
    /// [`RpcClientError`]. A plain connection loss never does.
    fn cause(self) -> Option<Self::Error> {
        match self {
            Self::Lost(_) => None,
            Self::ConnectError(traced) => traced.into_inner().cause(),
        }
    }
}
/// Client-side session with a media server, built on top of an RPC client.
#[async_trait(?Send)]
#[cfg_attr(feature = "mockable", mockall::automock)]
pub trait RpcSession {
    /// Connects using the provided [`ConnectionInfo`], resolving once the
    /// session is opened or has failed to open.
    async fn connect(
        self: Rc<Self>,
        connection_info: ConnectionInfo,
    ) -> Result<(), Traced<SessionError>>;

    /// Tries to re-open the session using the previously supplied
    /// [`ConnectionInfo`].
    async fn reconnect(self: Rc<Self>) -> Result<(), Traced<SessionError>>;

    /// Returns a stream of [`Event`]s received by this session.
    fn subscribe(&self) -> LocalBoxStream<'static, Event>;

    /// Sends the provided [`Command`] to the server.
    fn send_command(&self, command: Command);

    /// Resolves with the [`CloseReason`] once the session finishes.
    fn on_normal_close(&self) -> LocalBoxFuture<'static, CloseReason>;

    /// Closes the session with the provided [`ClientDisconnect`] reason.
    fn close_with_reason(&self, close_reason: ClientDisconnect);

    /// Returns a stream emitting whenever a previously established connection
    /// is lost.
    fn on_connection_loss(&self) -> LocalBoxStream<'static, ()>;

    /// Returns a stream emitting whenever the session is re-opened after a
    /// reconnect.
    fn on_reconnected(&self) -> LocalBoxStream<'static, ()>;

    /// Notifies the session that the network has changed, so its connection
    /// should be re-established.
    async fn network_changed(self: Rc<Self>)
    -> Result<(), Traced<SessionError>>;
}
/// [`RpcSession`] implementation driving a [`WebSocketRpcClient`] through
/// the [`SessionState`] machine.
#[derive(Debug)]
pub struct WebSocketRpcSession {
    /// Underlying RPC client used for transport and room commands.
    client: Rc<WebSocketRpcClient>,

    /// Current state of this session; its updates drive the background
    /// watchers.
    state: ObservableCell<SessionState>,

    /// Set once the session has been `Opened` before a loss, or when
    /// `network_changed` is called while `Opened`. Gates
    /// `on_connection_loss` notifications.
    was_connected: Rc<Cell<bool>>,

    /// Senders of all `subscribe` streams. Closed ones are pruned on the
    /// next delivered event.
    event_txs: RefCell<Vec<mpsc::UnboundedSender<Event>>>,
}
impl WebSocketRpcSession {
    /// Creates a new [`WebSocketRpcSession`] on top of the provided
    /// [`WebSocketRpcClient`] and spawns its background watchers.
    ///
    /// The session starts in [`SessionState::Uninitialized`]. Every spawned
    /// task captures only a `Weak` reference to the session.
    pub fn new(client: Rc<WebSocketRpcClient>) -> Rc<Self> {
        let this = Rc::new(Self {
            client,
            state: ObservableCell::new(SessionState::Uninitialized),
            was_connected: Rc::new(Cell::new(false)),
            event_txs: RefCell::default(),
        });

        this.spawn_state_watcher();
        this.spawn_connection_loss_watcher();
        this.spawn_close_watcher();
        this.spawn_server_msg_listener();

        this
    }

    /// Indicates whether this session is in the terminal
    /// [`SessionState::Finished`] state.
    fn is_finished(&self) -> bool {
        matches!(self.state.get(), SessionState::Finished(_))
    }

    /// Starts connecting (if not already in progress) and waits until the
    /// session becomes [`SessionState::Opened`].
    ///
    /// A session in [`SessionState::Initialized`] or [`SessionState::Lost`]
    /// is moved to [`SessionState::Connecting`]. A session that is already
    /// connecting, authorizing or opened is simply awaited.
    ///
    /// # Errors
    ///
    /// - [`SessionError::NoCredentials`]: the session is uninitialized.
    /// - [`SessionError::SessionFinished`]: the session is or becomes
    ///   finished.
    /// - [`SessionError::NewConnectionInfo`]: new credentials are provided
    ///   while waiting.
    /// - [`SessionError::ConnectionLost`]: the connection is lost while
    ///   waiting.
    /// - [`SessionError::AuthorizationFailed`]: the session is reset to
    ///   uninitialized while waiting.
    /// - [`SessionError::SessionUnexpectedlyDropped`]: the state stream ends.
    async fn inner_connect(self: Rc<Self>) -> Result<(), Traced<SessionError>> {
        use SessionError as E;
        use SessionState as S;

        match self.state.get() {
            S::Connecting(_) | S::Authorizing(_) | S::Opened { .. } => {}
            S::Initialized(info) | S::Lost(_, info) => {
                self.state.set(S::Connecting(info));
            }
            S::Uninitialized => {
                return Err(tracerr::new!(E::NoCredentials));
            }
            S::Finished(reason) => {
                return Err(tracerr::new!(E::SessionFinished(reason)));
            }
        }

        let mut state_updates_stream = self.state.subscribe();
        while let Some(state) = state_updates_stream.next().await {
            match state {
                S::Opened { .. } => return Ok(()),
                S::Initialized(_) => {
                    return Err(tracerr::new!(E::NewConnectionInfo));
                }
                S::Lost(reason, _) => {
                    return Err(tracerr::new!(E::ConnectionLost(reason)));
                }
                S::Uninitialized => {
                    return Err(tracerr::new!(E::AuthorizationFailed));
                }
                S::Finished(reason) => {
                    return Err(tracerr::new!(E::SessionFinished(reason)));
                }
                S::Connecting(_) | S::Authorizing(_) => {}
            }
        }

        Err(tracerr::new!(E::SessionUnexpectedlyDropped))
    }

    /// Spawns a task reacting to [`SessionState`] updates.
    ///
    /// - [`SessionState::Connecting`] opens the transport, then moves to
    ///   [`SessionState::Authorizing`] on success or [`SessionState::Lost`]
    ///   on failure.
    /// - [`SessionState::Authorizing`] sends a `join_room` request.
    ///
    /// Neither transition is applied once the session has finished, so a
    /// session closed while its transport was connecting stays finished.
    fn spawn_state_watcher(self: &Rc<Self>) {
        use SessionState as S;

        let mut state_updates = self.state.subscribe();
        let weak_this = Rc::downgrade(self);
        platform::spawn(async move {
            let capabilities = platform::get_capabilities().await;
            while let Some(state) = state_updates.next().await {
                let this = upgrade_or_break!(weak_this);
                match state {
                    S::Connecting(info) => {
                        let connected = Rc::clone(&this.client)
                            .connect(info.url.clone())
                            .await;

                        // The session may have been closed while the
                        // transport was connecting. `Finished` is terminal
                        // everywhere else in this type, so it must not be
                        // moved back to `Authorizing` or `Lost` here.
                        if this.is_finished() {
                            continue;
                        }

                        match connected {
                            Ok(()) => {
                                this.state.set(S::Authorizing(info));
                            }
                            Err(e) => {
                                this.state.set(S::Lost(
                                    ConnectionLostReason::ConnectError(e),
                                    info,
                                ));
                            }
                        }
                    }
                    S::Authorizing(info) => {
                        // This update may be stale (queued before a close),
                        // so don't join a room for a finished session.
                        if this.is_finished() {
                            continue;
                        }

                        this.client.join_room(
                            info.room_id.clone(),
                            info.member_id.clone(),
                            info.credential.clone(),
                            capabilities.clone(),
                        );
                    }
                    S::Uninitialized
                    | S::Initialized(_)
                    | S::Lost(..)
                    | S::Opened { .. }
                    | S::Finished(_) => {}
                }
            }
        });
    }

    /// Spawns a task translating the client's connection-loss notifications
    /// into [`SessionState::Lost`] for a connecting, authorizing or opened
    /// session.
    ///
    /// Losing an opened session also marks it as `was_connected`.
    fn spawn_connection_loss_watcher(self: &Rc<Self>) {
        use SessionState as S;

        let mut client_on_connection_loss = self.client.on_connection_loss();
        let weak_this = Rc::downgrade(self);
        platform::spawn(async move {
            while let Some(reason) = client_on_connection_loss.next().await {
                let this = upgrade_or_break!(weak_this);

                let state = this.state.get();
                if matches!(state, S::Opened { .. }) {
                    this.was_connected.set(true);
                }

                match state {
                    S::Connecting(info)
                    | S::Authorizing(info)
                    | S::Opened { info, .. } => {
                        this.state.set(S::Lost(
                            ConnectionLostReason::Lost(reason),
                            info,
                        ));
                    }
                    S::Uninitialized
                    | S::Initialized(_)
                    | S::Lost(_, _)
                    | S::Finished(_) => {}
                }
            }
        });
    }

    /// Spawns a task moving the session to [`SessionState::Finished`] once
    /// the client reports a normal close.
    ///
    /// If the client's close future resolves with an error, the reason
    /// becomes `RpcClientUnexpectedlyDropped`.
    fn spawn_close_watcher(self: &Rc<Self>) {
        let on_normal_close = self.client.on_normal_close();
        let weak_this = Rc::downgrade(self);
        platform::spawn(async move {
            let reason = on_normal_close.await.unwrap_or_else(|_| {
                ClientDisconnect::RpcClientUnexpectedlyDropped.into()
            });
            if let Some(this) = weak_this.upgrade() {
                this.state.set(SessionState::Finished(reason));
            }
        });
    }

    /// Spawns a task dispatching every server message from the client to
    /// this session's [`RpcEventHandler`] implementation.
    fn spawn_server_msg_listener(self: &Rc<Self>) {
        let mut server_msg_rx = self.client.subscribe();
        let weak_this = Rc::downgrade(self);
        platform::spawn(async move {
            while let Some(msg) = server_msg_rx.next().await {
                let this = upgrade_or_break!(weak_this);
                msg.dispatch_with(this.as_ref());
            }
        });
    }
}
#[async_trait(?Send)]
impl RpcSession for WebSocketRpcSession {
    /// Stores the provided [`ConnectionInfo`] when the current
    /// [`SessionState`] allows it, then waits until the session is opened.
    ///
    /// # Errors
    ///
    /// [`SessionError::SessionFinished`] if the session has already finished,
    /// or any error produced while waiting for the session to open.
    ///
    /// # Panics
    ///
    /// If a different [`ConnectionInfo`] is provided while the session is
    /// authorizing or already opened.
    async fn connect(
        self: Rc<Self>,
        connection_info: ConnectionInfo,
    ) -> Result<(), Traced<SessionError>> {
        use SessionState as S;

        match self.state.get() {
            S::Finished(reason) => {
                return Err(tracerr::new!(SessionError::SessionFinished(
                    reason
                )));
            }
            S::Uninitialized | S::Initialized(_) | S::Lost(..) => {
                self.state.set(S::Initialized(Rc::new(connection_info)));
            }
            S::Connecting(current) => {
                // Different credentials restart the attempt from scratch.
                if *current != connection_info {
                    self.state.set(S::Initialized(Rc::new(connection_info)));
                }
            }
            S::Authorizing(current) | S::Opened { info: current, .. } => {
                if *current != connection_info {
                    unimplemented!(
                        "Changing `ConnectionInfo` with active or pending \
                         authorization is not supported",
                    );
                }
            }
        }

        self.inner_connect().await.map_err(tracerr::map_from_and_wrap!())
    }

    /// Re-opens the session with its stored [`ConnectionInfo`].
    ///
    /// # Errors
    ///
    /// Any error produced while waiting for the session to open.
    async fn reconnect(self: Rc<Self>) -> Result<(), Traced<SessionError>> {
        self.inner_connect().await.map_err(tracerr::map_from_and_wrap!())
    }

    /// Registers a new event subscriber and returns its stream.
    fn subscribe(&self) -> LocalBoxStream<'static, Event> {
        let (event_tx, event_rx) = mpsc::unbounded();
        self.event_txs.borrow_mut().push(event_tx);
        event_rx.boxed_local()
    }

    /// Sends the [`Command`] to the opened room; silently drops it when the
    /// session is not opened.
    fn send_command(&self, command: Command) {
        let SessionState::Opened { info, .. } = self.state.get() else {
            return;
        };
        self.client.send_command(info.room_id.clone(), command);
    }

    /// Resolves with the reason carried by [`SessionState::Finished`].
    ///
    /// Falls back to `SessionUnexpectedlyDropped` if the state stream ends
    /// first.
    fn on_normal_close(&self) -> LocalBoxFuture<'static, CloseReason> {
        let mut finish_reasons = self
            .state
            .subscribe()
            .filter_map(async |state| match state {
                SessionState::Finished(reason) => Some(reason),
                SessionState::Uninitialized
                | SessionState::Initialized(_)
                | SessionState::Connecting(_)
                | SessionState::Authorizing(_)
                | SessionState::Lost(..)
                | SessionState::Opened { .. } => None,
            })
            .boxed_local();

        Box::pin(async move {
            match finish_reasons.next().await {
                Some(reason) => reason,
                None => ClientDisconnect::SessionUnexpectedlyDropped.into(),
            }
        })
    }

    /// Leaves the room (when opened), records the close reason on the client
    /// and finishes the session.
    fn close_with_reason(&self, close_reason: ClientDisconnect) {
        if let SessionState::Opened { info, .. } = self.state.get() {
            let room_id = info.room_id.clone();
            let member_id = info.member_id.clone();
            self.client.leave_room(room_id, member_id);
        }
        self.client.set_close_reason(close_reason);
        self.state.set(SessionState::Finished(close_reason.into()));
    }

    /// Emits on every [`SessionState::Lost`] update, but only once the
    /// session has been connected at least once.
    fn on_connection_loss(&self) -> LocalBoxStream<'static, ()> {
        let was_connected = Rc::clone(&self.was_connected);
        self.state
            .subscribe()
            .filter_map(move |state| {
                let lost_after_connect = matches!(state, SessionState::Lost(..))
                    && was_connected.get();
                future::ready(lost_after_connect.then_some(()))
            })
            .boxed_local()
    }

    /// Emits on every [`SessionState::Opened`] update flagged as a reconnect.
    fn on_reconnected(&self) -> LocalBoxStream<'static, ()> {
        self.state
            .subscribe()
            .filter_map(|state| {
                let reopened = matches!(
                    state,
                    SessionState::Opened { is_reconnect: true, .. }
                );
                future::ready(reopened.then_some(()))
            })
            .boxed_local()
    }

    /// Restarts the connection of an opened session after a network change
    /// and waits until it is opened again.
    ///
    /// # Errors
    ///
    /// [`SessionError::NoCredentials`] or [`SessionError::SessionFinished`]
    /// depending on the current state, or any error produced while waiting
    /// for the session to open.
    async fn network_changed(
        self: Rc<Self>,
    ) -> Result<(), Traced<SessionError>> {
        use SessionError as E;
        use SessionState as S;

        match self.state.get() {
            S::Uninitialized => {
                return Err(tracerr::new!(E::NoCredentials));
            }
            S::Finished(reason) => {
                return Err(tracerr::new!(E::SessionFinished(reason)));
            }
            S::Opened { info, .. } => {
                // Marks the session as previously connected, so a following
                // `Lost` state is reported by `on_connection_loss`.
                self.was_connected.set(true);
                self.client.close_for_reconnection();
                self.state.set(S::Connecting(info));
            }
            S::Initialized(_)
            | S::Connecting(_)
            | S::Authorizing(_)
            | S::Lost(..) => {}
        }

        self.inner_connect().await.map_err(tracerr::map_from_and_wrap!())
    }
}
impl RpcEventHandler for WebSocketRpcSession {
    type Output = ();

    /// Opens an authorizing session when the joined room and member match
    /// its [`ConnectionInfo`].
    fn on_joined_room(
        &self,
        room_id: RoomId,
        member_id: MemberId,
        is_reconnect: bool,
    ) {
        match self.state.get() {
            SessionState::Authorizing(info)
                if info.room_id == room_id && info.member_id == member_id =>
            {
                self.state.set(SessionState::Opened { info, is_reconnect });
            }
            SessionState::Uninitialized
            | SessionState::Initialized(_)
            | SessionState::Connecting(_)
            | SessionState::Authorizing(_)
            | SessionState::Lost(..)
            | SessionState::Opened { .. }
            | SessionState::Finished(_) => {}
        }
    }

    /// Handles the server leaving this session's room.
    ///
    /// An opened session finishes with the provided [`CloseReason`]. An
    /// authorizing session drops its credentials (becomes
    /// [`SessionState::Uninitialized`]). Other rooms and states are ignored.
    fn on_left_room(&self, room_id: RoomId, close_reason: CloseReason) {
        match self.state.get() {
            SessionState::Opened { info, .. } if info.room_id == room_id => {
                self.state.set(SessionState::Finished(close_reason));
            }
            SessionState::Authorizing(info) if info.room_id == room_id => {
                self.state.set(SessionState::Uninitialized);
            }
            SessionState::Uninitialized
            | SessionState::Initialized(_)
            | SessionState::Connecting(_)
            | SessionState::Authorizing(_)
            | SessionState::Lost(..)
            | SessionState::Opened { .. }
            | SessionState::Finished(_) => {}
        }
    }

    /// Forwards an [`Event`] of the opened room to every live subscriber,
    /// pruning subscribers whose receivers are gone.
    fn on_event(&self, room_id: RoomId, event: Event) {
        let SessionState::Opened { info, .. } = self.state.get() else {
            return;
        };
        if info.room_id != room_id {
            return;
        }
        self.event_txs
            .borrow_mut()
            .retain(|subscriber| subscriber.unbounded_send(event.clone()).is_ok());
    }
}
/// State of a [`WebSocketRpcSession`].
///
/// `PartialEq` ignores the [`ConnectionLostReason`] of the `Lost` variant.
#[derive(Clone, Debug, Derivative)]
#[derivative(PartialEq)]
pub enum SessionState {
    /// No [`ConnectionInfo`] is available.
    ///
    /// This is the initial state. It is also entered when the room is left
    /// during authorization.
    Uninitialized,

    /// [`ConnectionInfo`] is provided, but connecting hasn't started yet.
    Initialized(Rc<ConnectionInfo>),

    /// The transport connection is being established.
    Connecting(Rc<ConnectionInfo>),

    /// The transport is connected and a `join_room` request is pending.
    Authorizing(Rc<ConnectionInfo>),

    /// The connection was lost or could not be established.
    Lost(
        #[derivative(PartialEq = "ignore")] ConnectionLostReason,
        Rc<ConnectionInfo>,
    ),

    /// The room is joined; `is_reconnect` is the flag reported on join.
    Opened {
        info: Rc<ConnectionInfo>,
        is_reconnect: bool,
    },

    /// Terminal state carrying the [`CloseReason`].
    Finished(CloseReason),
}