Skip to main content

phoenix_chan/
client.rs

1//! Client for the Phoenix channel
2
3use std::ops::DerefMut;
4use std::pin::pin;
5use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6use std::time::Duration;
7
8use async_tungstenite::tokio::ConnectStream;
9use async_tungstenite::{WebSocketReceiver, WebSocketSender, WebSocketStream};
10use futures::StreamExt;
11use serde::Serialize;
12use serde::de::DeserializeOwned;
13use tokio::sync::Mutex;
14use tracing::{debug, instrument, trace};
15use tungstenite::http::Uri;
16
17use crate::message::{ChannelMsg, Message};
18use crate::{Builder, Error, Map};
19
/// Id to identify the response of a message sent by the client.
pub type Id = usize;

/// Write half of the split WebSocket connection, shared by every sending method.
type Sender = WebSocketSender<ConnectStream>;
/// Read half of the split WebSocket connection, polled by the receiving side.
type Receiver = WebSocketReceiver<ConnectStream>;
25
/// Read side of the client: the WebSocket receiver together with the heartbeat timer.
///
/// Both are kept behind a single lock (`Client::reader`) because the timer is only polled
/// while the receiver is being awaited in `Client::next_msg`.
#[derive(Debug)]
struct Reader {
    /// Timer whose ticks trigger a heartbeat check while waiting for incoming messages.
    heartbeat: tokio::time::Interval,
    /// Read half of the split WebSocket stream.
    receiver: Receiver,
}
31
/// Connection for the Phoenix channel
///
/// The write half and the read half of the WebSocket are guarded by separate async mutexes,
/// so one task can send while another is waiting for incoming messages.
#[derive(Debug)]
pub struct Client {
    /// Join id passed to every outgoing channel message (see `set_join_id`).
    join_id: AtomicUsize,
    /// Counter that hands out the per-message ids returned to callers.
    msg_id: AtomicUsize,
    /// Set after a message is written; cleared by the heartbeat check, which skips sending
    /// a heartbeat when the flag was set.
    sent: AtomicBool,
    /// Write half of the WebSocket stream.
    writer: Mutex<Sender>,
    /// Read half of the WebSocket stream plus the heartbeat timer.
    reader: Mutex<Reader>,
}
41
42impl Client {
43    pub(crate) fn new(connection: WebSocketStream<ConnectStream>, heartbeat: Duration) -> Self {
44        let (writer, reader) = connection.split();
45        Self {
46            join_id: AtomicUsize::new(1),
47            msg_id: AtomicUsize::new(1),
48            sent: AtomicBool::new(false),
49            writer: Mutex::new(writer),
50            reader: Mutex::new(Reader {
51                heartbeat: tokio::time::interval(heartbeat),
52                receiver: reader,
53            }),
54        }
55    }
56
57    fn next_id(&self) -> usize {
58        self.msg_id.fetch_add(1, Ordering::AcqRel)
59    }
60
61    /// Returns a builder to configure the client.
62    pub fn builder(uri: Uri) -> Result<Builder, Error> {
63        Builder::new(uri)
64    }
65
66    /// Sets the join id.
67    pub fn set_join_id(&self, join_id: usize) {
68        self.join_id.store(join_id, Ordering::Release);
69    }
70
71    /// Joins a channel.
72    pub async fn join(&self, topic: &str) -> Result<Id, Error> {
73        self.join_with_payload(topic, Map::default()).await
74    }
75
76    /// Joins a channel with additional parameters.
77    #[instrument(skip(self, payload))]
78    pub async fn join_with_payload<P>(&self, topic: &str, payload: P) -> Result<Id, Error>
79    where
80        P: Serialize,
81    {
82        let join_id = self.join_id.load(Ordering::Acquire);
83        let msg_id = self.next_id();
84
85        let msg = ChannelMsg::new(Some(join_id), Some(msg_id), topic, "phx_join", payload);
86
87        debug!(msg_id, "joining topic");
88
89        self.write_msg(msg).await?;
90
91        trace!(msg_id, "topic joined");
92
93        Ok(msg_id)
94    }
95
96    /// Leaves a channel.
97    #[instrument(skip(self))]
98    pub async fn leave(&self, topic: &str) -> Result<Id, Error> {
99        let join_id = self.join_id.load(Ordering::Relaxed);
100        let msg_id = self.next_id();
101
102        let msg = ChannelMsg::new(
103            Some(join_id),
104            Some(msg_id),
105            topic,
106            "phx_leave",
107            Map::default(),
108        );
109
110        debug!(msg_id, "leaving topic");
111
112        self.write_msg(msg).await?;
113
114        trace!(msg_id, "topic left");
115
116        Ok(msg_id)
117    }
118
119    /// Sends an event on a topic
120    #[instrument(skip(self, payload))]
121    pub async fn send<P>(&self, topic: &str, event: &str, payload: P) -> Result<Id, Error>
122    where
123        P: Serialize,
124    {
125        let join_id = self.join_id.load(Ordering::Relaxed);
126        let msg_id = self.next_id();
127
128        let msg = ChannelMsg::new(Some(join_id), Some(msg_id), topic, event, payload);
129
130        debug!(msg_id, "sending event");
131
132        self.write_msg(msg).await?;
133
134        trace!(msg_id, "event sent");
135
136        Ok(msg_id)
137    }
138
139    #[instrument(skip_all)]
140    async fn write_msg<P>(&self, msg: ChannelMsg<'_, P>) -> Result<(), Error>
141    where
142        P: Serialize,
143    {
144        let msg_json = serde_json::to_string(&msg).map_err(Error::Serialize)?;
145
146        trace!("writing on socket");
147
148        self.writer
149            .lock()
150            .await
151            .send(tungstenite::Message::Text(msg_json.into()))
152            .await
153            .map_err(Box::new)
154            .map_err(|err| Error::Send {
155                msg: msg.into_err(),
156                backtrace: err,
157            })?;
158
159        trace!("update sent flag");
160
161        self.sent.store(true, Ordering::Release);
162
163        Ok(())
164    }
165
166    /// Returns the next message in any channel.
167    #[instrument(skip(self))]
168    pub async fn recv<P>(&self) -> Result<Message<P>, Error>
169    where
170        P: DeserializeOwned,
171    {
172        trace!("waiting for next message");
173
174        let msg = self.next_msg().await?;
175
176        trace!(%msg, "WebSocket message received");
177
178        msg.into_text()
179            .map_err(Box::new)
180            .map_err(Error::WebSocketMessageType)
181            .and_then(|txt| {
182                serde_json::from_str::<ChannelMsg<P>>(txt.as_str()).map_err(Error::Deserialize)
183            })
184            .map(|msg| {
185                let msg = Message::from(msg);
186
187                debug!(message = msg.info(), "message received");
188
189                msg
190            })
191    }
192
193    #[instrument(skip(self))]
194    async fn next_msg(&self) -> Result<tungstenite::Message, Error> {
195        trace!("waiting for reader lock");
196        let mut reader = self.reader.lock().await;
197        let reader = reader.deref_mut();
198
199        let mut receive = reader.receiver.next();
200
201        loop {
202            trace!("waiting for next event or heartbeat");
203            match futures::future::select(pin!(reader.heartbeat.tick()), pin!(&mut receive)).await {
204                futures::future::Either::Left((_instant, _next)) => {
205                    trace!("heartbeat interval");
206                    self.check_and_send_heartbeat().await?;
207                }
208                futures::future::Either::Right((None, _)) => {
209                    debug!("WebSocket disconnected");
210
211                    return Err(Error::Disconnected);
212                }
213                futures::future::Either::Right((Some(res), _)) => {
214                    trace!("next event");
215
216                    return res.map_err(Box::new).map_err(Error::Recv);
217                }
218            };
219        }
220    }
221
222    #[instrument(skip(self))]
223    async fn check_and_send_heartbeat(&self) -> Result<(), Error> {
224        let val = self
225            .sent
226            .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire);
227
228        trace!(sent_flag = ?val, "heartbeat sent flag");
229
230        match val {
231            Ok(val) => {
232                debug_assert!(val);
233            }
234            Err(val) => {
235                debug_assert!(!val);
236
237                let id = self.next_id();
238
239                let heartbeat =
240                    ChannelMsg::new(None, Some(id), "phoenix", "heartbeat", Map::default());
241
242                debug!(id, "sending heartbeat");
243
244                self.write_msg(heartbeat).await?;
245            }
246        }
247
248        Ok(())
249    }
250}