phoenix_chan/
client.rs

1//! Client for the Phoenix channel
2
3use std::borrow::Cow;
4use std::ops::DerefMut;
5use std::pin::pin;
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::time::Duration;
8
9use async_tungstenite::tokio::ConnectStream;
10use async_tungstenite::WebSocketStream;
11use futures::stream::{SplitSink, SplitStream};
12use futures::{SinkExt, StreamExt};
13use serde::de::DeserializeOwned;
14use serde::Serialize;
15use tokio::sync::Mutex;
16use tungstenite::http::Uri;
17
18use crate::message::{ChannelMsg, Message};
19use crate::{Builder, Error, Map};
20
21/// Id to identify the response of a message sent by the client.
22pub type Id = usize;
23
24type Sender = SplitSink<WebSocketStream<ConnectStream>, tungstenite::Message>;
25type Receiver = SplitStream<WebSocketStream<ConnectStream>>;
26
27#[derive(Debug)]
28struct Reader {
29    heartbeat: tokio::time::Interval,
30    receiver: Receiver,
31}
32
33/// Connection for the Phoenix channel
34#[derive(Debug)]
35pub struct Client {
36    msg_id: AtomicUsize,
37    sent: AtomicBool,
38    writer: Mutex<Sender>,
39    reader: Mutex<Reader>,
40}
41
42impl Client {
43    pub(crate) fn new(connection: WebSocketStream<ConnectStream>, heartbeat: Duration) -> Self {
44        let (writer, reader) = connection.split();
45        Self {
46            msg_id: AtomicUsize::new(0),
47            sent: AtomicBool::new(false),
48            writer: Mutex::new(writer),
49            reader: Mutex::new(Reader {
50                heartbeat: tokio::time::interval(heartbeat),
51                receiver: reader,
52            }),
53        }
54    }
55
56    fn next_id(&self) -> usize {
57        self.msg_id.fetch_add(1, Ordering::AcqRel)
58    }
59
60    /// Returns a builder to configure the client.
61    pub fn builder(uri: Uri) -> Builder {
62        Builder::new(uri)
63    }
64
65    /// Joins a channel.
66    pub async fn join(&self, topic: &str) -> Result<Id, Error> {
67        self.join_with_payload(topic, Map::default()).await
68    }
69
70    /// Joins a channel with additional parameters.
71    pub async fn join_with_payload<P>(&self, topic: &str, payload: P) -> Result<Id, Error>
72    where
73        P: Serialize,
74    {
75        let id = self.next_id();
76
77        let msg = ChannelMsg {
78            join_reference: Some(Cow::Owned(id.to_string())),
79            message_reference: Cow::Owned(id.to_string()),
80            topic_name: Cow::Borrowed(topic),
81            event_name: Cow::Borrowed("phx_join"),
82            payload,
83        };
84
85        self.write_msg(msg).await?;
86
87        Ok(id)
88    }
89
90    /// Leaves a channel.
91    pub async fn leave(&self, topic: &str) -> Result<Id, Error> {
92        self.send(topic, "phx_leave", Map::default()).await
93    }
94
95    /// Sends an event on a topic
96    pub async fn send<P>(&self, topic: &str, event: &str, payload: P) -> Result<Id, Error>
97    where
98        P: Serialize,
99    {
100        let id = self.next_id();
101
102        let msg = ChannelMsg {
103            join_reference: None,
104            message_reference: Cow::Owned(id.to_string()),
105            topic_name: Cow::Borrowed(topic),
106            event_name: Cow::Borrowed(event),
107            payload,
108        };
109
110        self.write_msg(msg).await?;
111
112        Ok(id)
113    }
114
115    async fn write_msg<P>(&self, msg: ChannelMsg<'_, P>) -> Result<(), Error>
116    where
117        P: Serialize,
118    {
119        let msg_json = serde_json::to_string(&msg).map_err(Error::Serialize)?;
120
121        self.writer
122            .lock()
123            .await
124            .send(tungstenite::Message::Text(msg_json.into()))
125            .await
126            .map_err(|err| Error::Send {
127                msg: msg.into_err(),
128                backtrace: err,
129            })?;
130
131        self.sent.store(true, Ordering::Release);
132
133        Ok(())
134    }
135
136    /// Returns the next message in any channel.
137    pub async fn recv<P>(&self) -> Result<Message<P>, Error>
138    where
139        P: DeserializeOwned,
140    {
141        let msg = self.next_msg().await?;
142
143        msg.into_text()
144            .map_err(Error::WebSocketMessageType)
145            .and_then(|txt| {
146                serde_json::from_str::<ChannelMsg<P>>(txt.as_str()).map_err(Error::Deserialize)
147            })
148            .map(Message::from)
149    }
150
151    async fn next_msg(&self) -> Result<tungstenite::Message, Error> {
152        let mut reader = self.reader.lock().await;
153        let reader = reader.deref_mut();
154
155        let mut receive = reader.receiver.next();
156
157        let next = loop {
158            match futures::future::select(pin!(reader.heartbeat.tick()), pin!(&mut receive)).await {
159                futures::future::Either::Left((_instant, _next)) => {
160                    self.check_and_send_heartbeat().await?;
161                }
162                futures::future::Either::Right((next, _)) => break next,
163            };
164        };
165
166        next.ok_or(Error::Disconnected)?.map_err(Error::Recv)
167    }
168
169    async fn check_and_send_heartbeat(&self) -> Result<(), Error> {
170        let val = self
171            .sent
172            .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire);
173
174        match val {
175            Ok(val) => {
176                debug_assert!(val);
177            }
178            Err(val) => {
179                debug_assert!(!val);
180
181                let heartbeat = ChannelMsg {
182                    join_reference: None,
183                    message_reference: Cow::Owned(self.next_id().to_string()),
184                    topic_name: Cow::Borrowed("phoenix"),
185                    event_name: Cow::Borrowed("heartbeat"),
186                    payload: Map::default(),
187                };
188
189                self.write_msg(heartbeat).await?;
190            }
191        }
192
193        Ok(())
194    }
195}