1use std::borrow::Cow;
4use std::ops::DerefMut;
5use std::pin::pin;
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::time::Duration;
8
9use async_tungstenite::tokio::ConnectStream;
10use async_tungstenite::WebSocketStream;
11use futures::stream::{SplitSink, SplitStream};
12use futures::{SinkExt, StreamExt};
13use serde::de::DeserializeOwned;
14use serde::Serialize;
15use tokio::sync::Mutex;
16use tungstenite::http::Uri;
17
18use crate::message::{ChannelMsg, Message};
19use crate::{Builder, Error, Map};
20
/// Reference number attached to an outgoing message.
///
/// Returned by [`Client::join`], [`Client::join_with_payload`], [`Client::leave`] and
/// [`Client::send`] so callers can recognise the message they sent.
pub type Id = usize;

24type Sender = SplitSink<WebSocketStream<ConnectStream>, tungstenite::Message>;
25type Receiver = SplitStream<WebSocketStream<ConnectStream>>;
26
27#[derive(Debug)]
28struct Reader {
29 heartbeat: tokio::time::Interval,
30 receiver: Receiver,
31}
32
33#[derive(Debug)]
35pub struct Client {
36 msg_id: AtomicUsize,
37 sent: AtomicBool,
38 writer: Mutex<Sender>,
39 reader: Mutex<Reader>,
40}
41
42impl Client {
43 pub(crate) fn new(connection: WebSocketStream<ConnectStream>, heartbeat: Duration) -> Self {
44 let (writer, reader) = connection.split();
45 Self {
46 msg_id: AtomicUsize::new(0),
47 sent: AtomicBool::new(false),
48 writer: Mutex::new(writer),
49 reader: Mutex::new(Reader {
50 heartbeat: tokio::time::interval(heartbeat),
51 receiver: reader,
52 }),
53 }
54 }
55
56 fn next_id(&self) -> usize {
57 self.msg_id.fetch_add(1, Ordering::AcqRel)
58 }
59
60 pub fn builder(uri: Uri) -> Builder {
62 Builder::new(uri)
63 }
64
65 pub async fn join(&self, topic: &str) -> Result<Id, Error> {
67 self.join_with_payload(topic, Map::default()).await
68 }
69
70 pub async fn join_with_payload<P>(&self, topic: &str, payload: P) -> Result<Id, Error>
72 where
73 P: Serialize,
74 {
75 let id = self.next_id();
76
77 let msg = ChannelMsg {
78 join_reference: Some(Cow::Owned(id.to_string())),
79 message_reference: Cow::Owned(id.to_string()),
80 topic_name: Cow::Borrowed(topic),
81 event_name: Cow::Borrowed("phx_join"),
82 payload,
83 };
84
85 self.write_msg(msg).await?;
86
87 Ok(id)
88 }
89
90 pub async fn leave(&self, topic: &str) -> Result<Id, Error> {
92 self.send(topic, "phx_leave", Map::default()).await
93 }
94
95 pub async fn send<P>(&self, topic: &str, event: &str, payload: P) -> Result<Id, Error>
97 where
98 P: Serialize,
99 {
100 let id = self.next_id();
101
102 let msg = ChannelMsg {
103 join_reference: None,
104 message_reference: Cow::Owned(id.to_string()),
105 topic_name: Cow::Borrowed(topic),
106 event_name: Cow::Borrowed(event),
107 payload,
108 };
109
110 self.write_msg(msg).await?;
111
112 Ok(id)
113 }
114
115 async fn write_msg<P>(&self, msg: ChannelMsg<'_, P>) -> Result<(), Error>
116 where
117 P: Serialize,
118 {
119 let msg_json = serde_json::to_string(&msg).map_err(Error::Serialize)?;
120
121 self.writer
122 .lock()
123 .await
124 .send(tungstenite::Message::Text(msg_json.into()))
125 .await
126 .map_err(|err| Error::Send {
127 msg: msg.into_err(),
128 backtrace: err,
129 })?;
130
131 self.sent.store(true, Ordering::Release);
132
133 Ok(())
134 }
135
136 pub async fn recv<P>(&self) -> Result<Message<P>, Error>
138 where
139 P: DeserializeOwned,
140 {
141 let msg = self.next_msg().await?;
142
143 msg.into_text()
144 .map_err(Error::WebSocketMessageType)
145 .and_then(|txt| {
146 serde_json::from_str::<ChannelMsg<P>>(txt.as_str()).map_err(Error::Deserialize)
147 })
148 .map(Message::from)
149 }
150
151 async fn next_msg(&self) -> Result<tungstenite::Message, Error> {
152 let mut reader = self.reader.lock().await;
153 let reader = reader.deref_mut();
154
155 let mut receive = reader.receiver.next();
156
157 let next = loop {
158 match futures::future::select(pin!(reader.heartbeat.tick()), pin!(&mut receive)).await {
159 futures::future::Either::Left((_instant, _next)) => {
160 self.check_and_send_heartbeat().await?;
161 }
162 futures::future::Either::Right((next, _)) => break next,
163 };
164 };
165
166 next.ok_or(Error::Disconnected)?.map_err(Error::Recv)
167 }
168
169 async fn check_and_send_heartbeat(&self) -> Result<(), Error> {
170 let val = self
171 .sent
172 .compare_exchange(true, false, Ordering::SeqCst, Ordering::Acquire);
173
174 match val {
175 Ok(val) => {
176 debug_assert!(val);
177 }
178 Err(val) => {
179 debug_assert!(!val);
180
181 let heartbeat = ChannelMsg {
182 join_reference: None,
183 message_reference: Cow::Owned(self.next_id().to_string()),
184 topic_name: Cow::Borrowed("phoenix"),
185 event_name: Cow::Borrowed("heartbeat"),
186 payload: Map::default(),
187 };
188
189 self.write_msg(heartbeat).await?;
190 }
191 }
192
193 Ok(())
194 }
195}