1use std::ops::DerefMut;
4use std::pin::pin;
5use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6use std::time::Duration;
7
8use async_tungstenite::tokio::ConnectStream;
9use async_tungstenite::{WebSocketReceiver, WebSocketSender, WebSocketStream};
10use futures::StreamExt;
11use serde::Serialize;
12use serde::de::DeserializeOwned;
13use tokio::sync::Mutex;
14use tracing::{debug, instrument, trace};
15use tungstenite::http::Uri;
16
17use crate::message::{ChannelMsg, Message};
18use crate::{Builder, Error, Map};
19
/// Message reference returned by [`Client::join`], [`Client::leave`] and [`Client::send`].
pub type Id = usize;
22
23type Sender = WebSocketSender<ConnectStream>;
24type Receiver = WebSocketReceiver<ConnectStream>;
25
26#[derive(Debug)]
27struct Reader {
28 heartbeat: tokio::time::Interval,
29 receiver: Receiver,
30}
31
32#[derive(Debug)]
34pub struct Client {
35 join_id: AtomicUsize,
36 msg_id: AtomicUsize,
37 sent: AtomicBool,
38 writer: Mutex<Sender>,
39 reader: Mutex<Reader>,
40}
41
42impl Client {
43 pub(crate) fn new(connection: WebSocketStream<ConnectStream>, heartbeat: Duration) -> Self {
44 let (writer, reader) = connection.split();
45 Self {
46 join_id: AtomicUsize::new(1),
47 msg_id: AtomicUsize::new(1),
48 sent: AtomicBool::new(false),
49 writer: Mutex::new(writer),
50 reader: Mutex::new(Reader {
51 heartbeat: tokio::time::interval(heartbeat),
52 receiver: reader,
53 }),
54 }
55 }
56
57 fn next_id(&self) -> usize {
58 self.msg_id.fetch_add(1, Ordering::AcqRel)
59 }
60
61 pub fn builder(uri: Uri) -> Result<Builder, Error> {
63 Builder::new(uri)
64 }
65
66 pub fn set_join_id(&self, join_id: usize) {
68 self.join_id.store(join_id, Ordering::Release);
69 }
70
71 pub async fn join(&self, topic: &str) -> Result<Id, Error> {
73 self.join_with_payload(topic, Map::default()).await
74 }
75
76 #[instrument(skip(self, payload))]
78 pub async fn join_with_payload<P>(&self, topic: &str, payload: P) -> Result<Id, Error>
79 where
80 P: Serialize,
81 {
82 let join_id = self.join_id.load(Ordering::Acquire);
83 let msg_id = self.next_id();
84
85 let msg = ChannelMsg::new(Some(join_id), Some(msg_id), topic, "phx_join", payload);
86
87 debug!(msg_id, "joining topic");
88
89 self.write_msg(msg).await?;
90
91 trace!(msg_id, "topic joined");
92
93 Ok(msg_id)
94 }
95
96 #[instrument(skip(self))]
98 pub async fn leave(&self, topic: &str) -> Result<Id, Error> {
99 let join_id = self.join_id.load(Ordering::Relaxed);
100 let msg_id = self.next_id();
101
102 let msg = ChannelMsg::new(
103 Some(join_id),
104 Some(msg_id),
105 topic,
106 "phx_leave",
107 Map::default(),
108 );
109
110 debug!(msg_id, "leaving topic");
111
112 self.write_msg(msg).await?;
113
114 trace!(msg_id, "topic left");
115
116 Ok(msg_id)
117 }
118
119 #[instrument(skip(self, payload))]
121 pub async fn send<P>(&self, topic: &str, event: &str, payload: P) -> Result<Id, Error>
122 where
123 P: Serialize,
124 {
125 let join_id = self.join_id.load(Ordering::Relaxed);
126 let msg_id = self.next_id();
127
128 let msg = ChannelMsg::new(Some(join_id), Some(msg_id), topic, event, payload);
129
130 debug!(msg_id, "sending event");
131
132 self.write_msg(msg).await?;
133
134 trace!(msg_id, "event sent");
135
136 Ok(msg_id)
137 }
138
139 #[instrument(skip_all)]
140 async fn write_msg<P>(&self, msg: ChannelMsg<'_, P>) -> Result<(), Error>
141 where
142 P: Serialize,
143 {
144 let msg_json = serde_json::to_string(&msg).map_err(Error::Serialize)?;
145
146 trace!("writing on socket");
147
148 self.writer
149 .lock()
150 .await
151 .send(tungstenite::Message::Text(msg_json.into()))
152 .await
153 .map_err(Box::new)
154 .map_err(|err| Error::Send {
155 msg: msg.into_err(),
156 backtrace: err,
157 })?;
158
159 trace!("update sent flag");
160
161 self.sent.store(true, Ordering::Release);
162
163 Ok(())
164 }
165
166 #[instrument(skip(self))]
168 pub async fn recv<P>(&self) -> Result<Message<P>, Error>
169 where
170 P: DeserializeOwned,
171 {
172 trace!("waiting for next message");
173
174 let msg = self.next_msg().await?;
175
176 trace!(%msg, "WebSocket message received");
177
178 msg.into_text()
179 .map_err(Box::new)
180 .map_err(Error::WebSocketMessageType)
181 .and_then(|txt| {
182 serde_json::from_str::<ChannelMsg<P>>(txt.as_str()).map_err(Error::Deserialize)
183 })
184 .map(|msg| {
185 let msg = Message::from(msg);
186
187 debug!(message = msg.info(), "message received");
188
189 msg
190 })
191 }
192
193 #[instrument(skip(self))]
194 async fn next_msg(&self) -> Result<tungstenite::Message, Error> {
195 trace!("waiting for reader lock");
196 let mut reader = self.reader.lock().await;
197 let reader = reader.deref_mut();
198
199 let mut receive = reader.receiver.next();
200
201 loop {
202 trace!("waiting for next event or heartbeat");
203 match futures::future::select(pin!(reader.heartbeat.tick()), pin!(&mut receive)).await {
204 futures::future::Either::Left((_instant, _next)) => {
205 trace!("heartbeat interval");
206 self.check_and_send_heartbeat().await?;
207 }
208 futures::future::Either::Right((None, _)) => {
209 debug!("WebSocket disconnected");
210
211 return Err(Error::Disconnected);
212 }
213 futures::future::Either::Right((Some(res), _)) => {
214 trace!("next event");
215
216 return res.map_err(Box::new).map_err(Error::Recv);
217 }
218 };
219 }
220 }
221
222 #[instrument(skip(self))]
223 async fn check_and_send_heartbeat(&self) -> Result<(), Error> {
224 let val = self
225 .sent
226 .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire);
227
228 trace!(sent_flag = ?val, "heartbeat sent flag");
229
230 match val {
231 Ok(val) => {
232 debug_assert!(val);
233 }
234 Err(val) => {
235 debug_assert!(!val);
236
237 let id = self.next_id();
238
239 let heartbeat =
240 ChannelMsg::new(None, Some(id), "phoenix", "heartbeat", Map::default());
241
242 debug!(id, "sending heartbeat");
243
244 self.write_msg(heartbeat).await?;
245 }
246 }
247
248 Ok(())
249 }
250}