use std::ops::DerefMut;
use std::pin::pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use async_tungstenite::tokio::ConnectStream;
use async_tungstenite::{WebSocketReceiver, WebSocketSender, WebSocketStream};
use futures::StreamExt;
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::Mutex;
use tracing::{debug, instrument, trace};
use tungstenite::http::Uri;
use crate::message::{ChannelMsg, Message};
use crate::{Builder, Error, Map};
/// Message reference returned by [`Client::join`], [`Client::leave`] and [`Client::send`]
/// (the `msg_id` placed on the outgoing message).
pub type Id = usize;
/// Write half of the split WebSocket connection.
type Sender = WebSocketSender<ConnectStream>;
/// Read half of the split WebSocket connection.
type Receiver = WebSocketReceiver<ConnectStream>;
/// Receive-side state, kept behind a single lock so that only one task at a time polls the
/// heartbeat timer and the read half of the socket.
#[derive(Debug)]
struct Reader {
    /// Timer whose ticks trigger the heartbeat check while a receive is in progress.
    heartbeat: tokio::time::Interval,
    /// Read half of the WebSocket connection.
    receiver: Receiver,
}
/// Client for a WebSocket connection speaking the Phoenix channel message protocol
/// (`phx_join`, `phx_leave`, `heartbeat`).
///
/// All methods take `&self`; the write and read halves are each guarded by an async mutex.
#[derive(Debug)]
pub struct Client {
    /// Join reference attached to outgoing join/leave/send messages.
    join_id: AtomicUsize,
    /// Counter used to allocate message references.
    msg_id: AtomicUsize,
    /// Set after a message is written; read and cleared by the heartbeat check.
    sent: AtomicBool,
    /// Write half of the connection.
    writer: Mutex<Sender>,
    /// Read half of the connection plus the heartbeat timer.
    reader: Mutex<Reader>,
}
impl Client {
    /// Creates a client from an established WebSocket connection.
    ///
    /// `heartbeat` is the period of the timer that [`Client::recv`] polls to decide when
    /// to send a heartbeat.
    ///
    /// # Panics
    ///
    /// Panics if `heartbeat` is zero (`tokio::time::interval` rejects a zero period).
    pub(crate) fn new(connection: WebSocketStream<ConnectStream>, heartbeat: Duration) -> Self {
        let (writer, receiver) = connection.split();
        let mut heartbeat = tokio::time::interval(heartbeat);
        // The timer is only polled while `recv` is running. With the default `Burst`
        // behaviour, a caller that stops receiving for several periods would get a run of
        // back-to-back ticks (and heartbeats) on resume; `Delay` yields a single tick instead.
        heartbeat.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        Self {
            join_id: AtomicUsize::new(1),
            msg_id: AtomicUsize::new(1),
            sent: AtomicBool::new(false),
            writer: Mutex::new(writer),
            reader: Mutex::new(Reader {
                heartbeat,
                receiver,
            }),
        }
    }

    /// Allocates the next message reference (starting at 1, incremented on every call).
    fn next_id(&self) -> Id {
        self.msg_id.fetch_add(1, Ordering::AcqRel)
    }

    /// Returns a [`Builder`] for a connection to `uri`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Builder::new`.
    pub fn builder(uri: Uri) -> Result<Builder, Error> {
        Builder::new(uri)
    }

    /// Sets the join reference attached to subsequent join, leave and send messages.
    pub fn set_join_id(&self, join_id: usize) {
        self.join_id.store(join_id, Ordering::Release);
    }

    /// Sends a `phx_join` for `topic` with a default (`Map::default()`) payload.
    ///
    /// Returns the message reference of the join request. The call completes once the
    /// request is written; it does not wait for the server's reply.
    ///
    /// # Errors
    ///
    /// Same as [`Client::join_with_payload`].
    pub async fn join(&self, topic: &str) -> Result<Id, Error> {
        self.join_with_payload(topic, Map::default()).await
    }

    /// Sends a `phx_join` for `topic` carrying `payload`.
    ///
    /// Returns the message reference of the join request. The call completes once the
    /// request is written; it does not wait for the server's reply.
    ///
    /// # Errors
    ///
    /// [`Error::Serialize`] if the message cannot be serialized, [`Error::Send`] if the
    /// socket write fails.
    #[instrument(skip(self, payload))]
    pub async fn join_with_payload<P>(&self, topic: &str, payload: P) -> Result<Id, Error>
    where
        P: Serialize,
    {
        let join_id = self.join_id.load(Ordering::Acquire);
        let msg_id = self.next_id();
        let msg = ChannelMsg::new(Some(join_id), Some(msg_id), topic, "phx_join", payload);
        debug!(msg_id, "joining topic");
        self.write_msg(msg).await?;
        trace!(msg_id, "topic joined");
        Ok(msg_id)
    }

    /// Sends a `phx_leave` for `topic` with a default (`Map::default()`) payload.
    ///
    /// Returns the message reference of the leave request once it is written.
    ///
    /// # Errors
    ///
    /// [`Error::Serialize`] if the message cannot be serialized, [`Error::Send`] if the
    /// socket write fails.
    #[instrument(skip(self))]
    pub async fn leave(&self, topic: &str) -> Result<Id, Error> {
        // Same ordering as every other `join_id` load, pairing with `set_join_id`'s Release.
        let join_id = self.join_id.load(Ordering::Acquire);
        let msg_id = self.next_id();
        let msg = ChannelMsg::new(
            Some(join_id),
            Some(msg_id),
            topic,
            "phx_leave",
            Map::default(),
        );
        debug!(msg_id, "leaving topic");
        self.write_msg(msg).await?;
        trace!(msg_id, "topic left");
        Ok(msg_id)
    }

    /// Sends `event` with `payload` on `topic`.
    ///
    /// Returns the message reference of the pushed event once it is written.
    ///
    /// # Errors
    ///
    /// [`Error::Serialize`] if the message cannot be serialized, [`Error::Send`] if the
    /// socket write fails.
    #[instrument(skip(self, payload))]
    pub async fn send<P>(&self, topic: &str, event: &str, payload: P) -> Result<Id, Error>
    where
        P: Serialize,
    {
        let join_id = self.join_id.load(Ordering::Acquire);
        let msg_id = self.next_id();
        let msg = ChannelMsg::new(Some(join_id), Some(msg_id), topic, event, payload);
        debug!(msg_id, "sending event");
        self.write_msg(msg).await?;
        trace!(msg_id, "event sent");
        Ok(msg_id)
    }

    /// Writes an application message and records that traffic was sent, so the next
    /// heartbeat check can skip its heartbeat.
    ///
    /// # Errors
    ///
    /// Same as [`Client::write_raw`].
    #[instrument(skip_all)]
    async fn write_msg<P>(&self, msg: ChannelMsg<'_, P>) -> Result<(), Error>
    where
        P: Serialize,
    {
        self.write_raw(msg).await?;
        trace!("update sent flag");
        self.sent.store(true, Ordering::Release);
        Ok(())
    }

    /// Serializes `msg` to JSON and writes it as a text frame, without touching the
    /// `sent` flag.
    ///
    /// # Errors
    ///
    /// [`Error::Serialize`] if serialization fails; [`Error::Send`] (carrying the
    /// unsent message) if the socket write fails.
    #[instrument(skip_all)]
    async fn write_raw<P>(&self, msg: ChannelMsg<'_, P>) -> Result<(), Error>
    where
        P: Serialize,
    {
        let msg_json = serde_json::to_string(&msg).map_err(Error::Serialize)?;
        trace!("writing on socket");
        self.writer
            .lock()
            .await
            .send(tungstenite::Message::Text(msg_json.into()))
            .await
            .map_err(Box::new)
            .map_err(|err| Error::Send {
                msg: msg.into_err(),
                backtrace: err,
            })
    }

    /// Waits for the next channel message and deserializes its payload as `P`.
    ///
    /// Heartbeats are sent from within this call whenever the heartbeat timer ticks.
    ///
    /// # Errors
    ///
    /// [`Error::Disconnected`] when the stream ends, [`Error::Recv`] on a WebSocket read
    /// error, [`Error::WebSocketMessageType`] if the frame cannot be turned into text,
    /// [`Error::Deserialize`] if the text is not a valid channel message, plus any error
    /// from writing a heartbeat.
    #[instrument(skip(self))]
    pub async fn recv<P>(&self) -> Result<Message<P>, Error>
    where
        P: DeserializeOwned,
    {
        trace!("waiting for next message");
        let msg = self.next_msg().await?;
        trace!(%msg, "WebSocket message received");
        msg.into_text()
            .map_err(Box::new)
            .map_err(Error::WebSocketMessageType)
            .and_then(|txt| {
                serde_json::from_str::<ChannelMsg<P>>(txt.as_str()).map_err(Error::Deserialize)
            })
            .map(|msg| {
                let msg = Message::from(msg);
                debug!(message = msg.info(), "message received");
                msg
            })
    }

    /// Waits for the next non-control WebSocket frame, running the heartbeat check on
    /// every timer tick in the meantime.
    #[instrument(skip(self))]
    async fn next_msg(&self) -> Result<tungstenite::Message, Error> {
        trace!("waiting for reader lock");
        let mut reader = self.reader.lock().await;
        let reader = reader.deref_mut();
        loop {
            trace!("waiting for next event or heartbeat");
            // `StreamExt::next` is cancel-safe, so building a fresh `Next` each iteration
            // (after the timer won the race) does not drop any frame.
            let next = reader.receiver.next();
            match futures::future::select(pin!(reader.heartbeat.tick()), pin!(next)).await {
                futures::future::Either::Left(_) => {
                    trace!("heartbeat interval");
                    self.check_and_send_heartbeat().await?;
                }
                futures::future::Either::Right((None, _)) => {
                    debug!("WebSocket disconnected");
                    return Err(Error::Disconnected);
                }
                // Ping/Pong are WebSocket control frames that carry no channel message
                // (tungstenite answers pings itself); turning them into text and parsing
                // them as JSON would only produce a spurious deserialize error.
                futures::future::Either::Right((
                    Some(Ok(tungstenite::Message::Ping(_) | tungstenite::Message::Pong(_))),
                    _,
                )) => {
                    trace!("skipping control frame");
                }
                futures::future::Either::Right((Some(res), _)) => {
                    trace!("next event");
                    return res.map_err(Box::new).map_err(Error::Recv);
                }
            }
        }
    }

    /// Sends a heartbeat unless an application message was written since the previous
    /// check.
    ///
    /// # Errors
    ///
    /// Same as [`Client::write_raw`].
    #[instrument(skip(self))]
    async fn check_and_send_heartbeat(&self) -> Result<(), Error> {
        // Atomically read-and-clear the flag: `Ok` means something was written since the
        // last tick, which stands in for the heartbeat this period.
        let val = self
            .sent
            .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire);
        trace!(sent_flag = ?val, "heartbeat sent flag");
        match val {
            Ok(val) => {
                debug_assert!(val);
            }
            Err(val) => {
                debug_assert!(!val);
                let id = self.next_id();
                let heartbeat =
                    ChannelMsg::new(None, Some(id), "phoenix", "heartbeat", Map::default());
                debug!(id, "sending heartbeat");
                // Deliberately bypass `write_msg`: if the heartbeat set `sent`, the next tick
                // would skip, and an idle connection would only heartbeat every other period.
                self.write_raw(heartbeat).await?;
            }
        }
        Ok(())
    }
}