1use super::config::Config;
2use crate::{
3 entities::ws::{
4 DocumentSubscribeResponse, Empty, EventSubscription, PhxReply, ReceivePhoenixMessage,
5 SendPhoenixMessage,
6 },
7 Auth, Subscription, WebsocketConnectionError,
8};
9use async_tungstenite::{tokio::connect_async, tungstenite::Message};
10use backoff::{backoff::Backoff, ExponentialBackoff};
11use futures::{future::BoxFuture, FutureExt, SinkExt, Stream};
12use serde::{de::DeserializeOwned, Serialize};
13use serde_json::{json, Value};
14use std::{collections::HashMap, fmt::Debug, time::Duration};
15use tokio::{
16 select,
17 sync::{broadcast, mpsc},
18 task,
19 time::{sleep, timeout},
20};
21use tokio_stream::{
22 wrappers::{BroadcastStream, ReceiverStream},
23 StreamExt,
24};
25use tokio_util::sync::CancellationToken;
26use tracing::Instrument;
27use uuid::Uuid;
28
29#[derive(Debug)]
30enum SubOp {
31 AddSubscription(String, SubscriptionRef),
32 RemoveSubscription(String),
33}
34
35type CloseState = (
36 mpsc::Receiver<Message>,
37 mpsc::UnboundedReceiver<SubOp>,
38 HashMap<String, SubscriptionRef>,
39);
40
41pub(super) struct Socket {
42 auth: Auth,
43 config: Config,
44 join_ref: Uuid,
45 outgoing_messages: (mpsc::Sender<Message>, Option<mpsc::Receiver<Message>>),
46 incoming_messages: (
47 broadcast::Sender<ReceivePhoenixMessage<Value>>,
48 broadcast::Receiver<ReceivePhoenixMessage<Value>>,
49 ),
50 subscriptions: Option<HashMap<String, SubscriptionRef>>,
51 sub_ops: (
52 mpsc::UnboundedSender<SubOp>,
53 Option<mpsc::UnboundedReceiver<SubOp>>,
54 ),
55 cancellation_token: CancellationToken,
56 handle: Option<BoxFuture<'static, Result<CloseState, WebsocketConnectionError>>>,
57}
58
59impl Debug for Socket {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("Socket")
62 .field("auth", &self.auth)
63 .field("config", &self.config)
64 .field("join_ref", &self.join_ref)
65 .finish()
66 }
67}
68
69impl Socket {
70 pub fn new(auth: Auth, config: Config) -> Self {
71 let (outgoing_messages_sender, outgoing_messages_receiver) =
72 mpsc::channel(config.outgoing_capacity);
73 let incoming_messages = broadcast::channel(config.incoming_capacity);
74 let (sub_ops_sender, sub_ops_receiver) = mpsc::unbounded_channel();
75
76 Self {
77 auth,
78 config,
79 join_ref: Uuid::new_v4(),
80 outgoing_messages: (outgoing_messages_sender, Some(outgoing_messages_receiver)),
81 incoming_messages,
82 subscriptions: Some(Default::default()),
83 sub_ops: (sub_ops_sender, Some(sub_ops_receiver)),
84 cancellation_token: CancellationToken::new(),
85 handle: None,
86 }
87 }
88
89 pub fn client(&self) -> SocketClient {
90 SocketClient {
91 join_ref: self.join_ref,
92 outgoing_messages: self.outgoing_messages.0.clone(),
93 incoming_messages: self.incoming_messages.0.clone(),
94 sub_ops: self.sub_ops.0.clone(),
95 request_timeout: self.config.request_timeout,
96 cancellation_token: self.cancellation_token.clone(),
97 }
98 }
99
100 pub async fn connect(&mut self) -> Result<(), WebsocketConnectionError> {
101 let mut query = vec![("vsn", self.config.version.clone())];
102 match &self.auth {
103 Auth::ClientId(client_id) => query.push(("client_id", client_id.clone())),
104 Auth::AccessToken(token) => query.push(("token", token.clone())),
105 Auth::RefreshableAccessToken(token) => {
106 let access_token = token.access_token().await?;
107 query.push(("token", access_token.access_token));
108 }
109 Auth::ClientCredentials(client_credentials) => {
110 let access_token = client_credentials.access_token().await?;
111 query.push(("token", access_token.access_token));
112 }
113 }
114
115 let query_str = serde_urlencoded::to_string(query.as_slice())?;
116 let connection_url = format!("{}?{}", self.config.api_url, query_str);
117
118 let (ws_stream, _) = connect_async(&connection_url).await?;
119 let (mut ws_tx, mut ws_rx) = futures::StreamExt::split(ws_stream);
120
121 let cancellation_token = self.cancellation_token.child_token();
122
123 let outgoing_messages_handle = {
124 let mut outgoing_messages_receiver = self
125 .outgoing_messages
126 .1
127 .take()
128 .ok_or(WebsocketConnectionError::AlreadyConnected)?;
129 let cancellation_token = cancellation_token.clone();
130 task::spawn(async move {
131 loop {
132 select! {
133 _ = cancellation_token.cancelled() => {
134 tracing::trace!("received cancellation signal");
135 break;
136 }
137 msg = outgoing_messages_receiver.recv() => {
138 match msg {
139 Some(msg) => {
140 tracing::trace!(?msg, "sending message");
141 if let Err(err) = ws_tx.send(msg).await {
142 tracing::error!(?err, "failed to send message on the socket");
143 cancellation_token.cancel();
144 break;
145 }
146 }
147 None => {
148 tracing::trace!("all senders were dropped");
149 cancellation_token.cancel();
150 break;
151 }
152 }
153 }
154 }
155 }
156
157 outgoing_messages_receiver
158 })
159 .instrument(tracing::trace_span!("outgoing_messages"))
160 };
161
162 let incoming_messages_handle = {
163 let cancellation_token = cancellation_token.clone();
164 let incoming_messages_sender = self.incoming_messages.0.clone();
165 task::spawn(async move {
166 loop {
167 select! {
168 _ = cancellation_token.cancelled() => {
169 tracing::trace!("received cancellation signal");
170 break;
171 }
172 msg = ws_rx.next() => {
173 match msg {
174 Some(Ok(Message::Text(text))) => {
175 match serde_json::from_str::<ReceivePhoenixMessage<Value>>(&text) {
176 Ok(msg) => {
177 if msg.event == "phx_error" {
178 tracing::error!(?msg.payload, "error on socket");
179 cancellation_token.cancel();
180 break;
181 }
182
183 tracing::trace!(?msg, "incoming message");
184 if let Err(err) = incoming_messages_sender.send(msg) {
185 tracing::error!(?text, ?err, "failed to broadcast incoming message");
186 }
187 }
188 Err(err) => {
189 tracing::error!(?text, ?err, "failed to deserialize glimesh message");
190 }
191 }
192 }
193 Some(Ok(Message::Close(reason))) => {
194 tracing::error!(?reason, "socket closed");
195 cancellation_token.cancel();
196 break;
197 }
198 Some(Ok(frame)) => {
199 tracing::error!(?frame, "unexpected frame type");
200 cancellation_token.cancel();
201 break;
202 }
203 Some(Err(err)) => {
204 tracing::error!(?err, "socket error");
205 cancellation_token.cancel();
206 break;
207 }
208 None => {
209 tracing::error!("no more socket messages");
212 cancellation_token.cancel();
213 break;
214 }
215 }
216 }
217 }
218 }
219 })
220 .instrument(tracing::trace_span!("incoming_messages"))
221 };
222
223 let socket_client = self.client();
224 if let Err(err) = socket_client
225 .request::<_, Empty>("__absinthe__:control".into(), "phx_join".into(), Empty {})
226 .await
227 {
228 tracing::error!(?err, "join request failed");
229 cancellation_token.cancel();
230 return Err(err);
231 }
232
233 let pinger_handle = {
234 let ping_interval = Duration::from_secs(30);
235 let cancellation_token = cancellation_token.clone();
236 task::spawn(async move {
237 loop {
238 select! {
239 _ = cancellation_token.cancelled() => {
240 tracing::trace!("received cancellation signal");
241 break;
242 }
243 _ = sleep(ping_interval) => {
244 if let Err(err) = socket_client.request::<_, Empty>(
245 "phoenix".into(),
246 "heartbeat".into(),
247 Empty {},
248 )
249 .await {
250 tracing::error!(?err, "failed to send ping");
251 cancellation_token.cancel();
252 break;
253 }
254 }
255 };
256 }
257 })
258 .instrument(tracing::trace_span!("pinger"))
259 };
260
261 let subscriptions_handle = {
262 let socket_client = self.client();
263 let mut sub_ops_receiver = self
264 .sub_ops
265 .1
266 .take()
267 .ok_or(WebsocketConnectionError::AlreadyConnected)?;
268 let mut subscriptions = self
269 .subscriptions
270 .take()
271 .ok_or(WebsocketConnectionError::AlreadyConnected)?;
272 task::spawn(async move {
273 let sub_ids = subscriptions.keys().cloned().collect::<Vec<_>>();
274 for old_sub_id in sub_ids {
275 if cancellation_token.is_cancelled() {
277 break;
278 }
279
280 let sub = subscriptions.remove(&old_sub_id).unwrap();
281 let op = || async {
282 let res = socket_client
283 .request::<_, DocumentSubscribeResponse>(
284 "__absinthe__:control".into(),
285 "doc".into(),
286 &sub.payload,
287 )
288 .await;
289
290 match res {
291 Ok(subscription) => {
292 Ok(subscription.response.subscription_id)
293 }
294 Err(err) => {
295 tracing::debug!(?err, ?sub, "failed to resubscribe");
296
297 if cancellation_token.is_cancelled() {
298 Err(backoff::Error::permanent(err))
299 } else {
300 Err(backoff::Error::transient(err))
301 }
302 }
303 }
304 };
305
306 match backoff::future::retry(ExponentialBackoff::default(), op).await {
307 Ok(sub_id) => {
308 tracing::debug!(?sub, "resubscribed");
309 subscriptions.insert(sub_id, sub);
310 }
311 Err(err) => {
312 tracing::error!(?err, "fatal error trying to resubscribe to subscriptions (did the socket die?)");
313 subscriptions.insert(old_sub_id, sub);
315 break;
317 }
318 }
319 }
320
321 if !cancellation_token.is_cancelled() {
322 let mut messages = socket_client.filter_messages::<EventSubscription, _>(|msg| {
323 msg.event == "subscription:data" && msg.topic.starts_with("__absinthe__:doc")
324 });
325
326 loop {
327 select! {
328 _ = cancellation_token.cancelled() => {
329 tracing::trace!("received cancellation signal");
330 break;
331 }
332 sub = sub_ops_receiver.recv() => {
333 match sub {
334 Some(SubOp::AddSubscription(sub_id, sub)) => {
335 subscriptions.insert(sub_id, sub);
336 }
337 Some(SubOp::RemoveSubscription(sub_id)) => {
338 subscriptions.remove(&sub_id);
339 let socket_client = socket_client.clone();
340 task::spawn(async move {
341 let payload = json!({ "subscriptionId": sub_id });
342 if let Err(err) = socket_client.send_message(
343 "__absinthe__:control".into(),
344 "unsubscribe".into(),
345 payload
346 ).await {
347 tracing::error!(?err, "failed to send unsubscribe request");
348 }
349 });
350 }
351 None => {
352 tracing::trace!("all senders were dropped");
353 cancellation_token.cancel();
354 break;
355 }
356 }
357 }
358 msg = messages.next() => {
359 match msg {
360 Some(EventSubscription{ result, subscription_id }) => {
361 if let Some(subscriber) = subscriptions.get(&subscription_id) {
362 match serde_json::from_value::<graphql_client::Response<Value>>(result) {
363 Ok(msg) => {
364 if let Err(err) = subscriber.sender.send(msg).await {
365 tracing::error!(?err, "failed to notify subscriber of event");
366 }
367 }
368 Err(err) => {
369 tracing::error!(?err, "invalid subscription message received");
370 }
371 }
372 }
373 }
374 None => {
375 tracing::trace!("all senders were dropped");
376 cancellation_token.cancel();
377 break;
378 }
379 }
380 }
381 }
382 }
383 }
384
385 (sub_ops_receiver, subscriptions)
386 })
387 .instrument(tracing::trace_span!("subscriptions"))
388 };
389
390 self.handle.replace(
391 async move {
392 incoming_messages_handle
393 .await
394 .map_err(anyhow::Error::from)?;
395 pinger_handle.await.map_err(anyhow::Error::from)?;
396 let outgoing_messages_receiver = outgoing_messages_handle
397 .await
398 .map_err(anyhow::Error::from)?;
399 let (sub_ops_receiver, subscriptions) =
400 subscriptions_handle.await.map_err(anyhow::Error::from)?;
401 Ok::<_, WebsocketConnectionError>((
402 outgoing_messages_receiver,
403 sub_ops_receiver,
404 subscriptions,
405 ))
406 }
407 .boxed(),
408 );
409
410 tracing::debug!("connected to socket");
411
412 Ok(())
413 }
414
415 pub fn stay_conected(mut self) {
416 task::spawn(async move {
417 loop {
418 if let Err(err) = self.wait().await {
419 tracing::error!(?err, "irrecoverable connecton error");
420 break;
422 }
423
424 if self.cancellation_token.is_cancelled() {
425 break;
426 }
427
428 let mut backoff = ExponentialBackoff::default();
429 while let Err(err) = self.connect().await {
430 match backoff.next_backoff() {
431 Some(backoff_time) => {
432 tracing::error!(
433 ?err,
434 "failed to reconnect, retrying in {:?}",
435 backoff_time
436 );
437 sleep(backoff_time).await;
438 }
439 None => {
440 tracing::error!(?err, "failed to reconnect, after many attempts");
441 return;
443 }
444 }
445 }
446
447 tracing::info!("successfully reconnected")
448 }
449 });
450 }
451
452 async fn wait(&mut self) -> Result<(), WebsocketConnectionError> {
453 let handle = self
454 .handle
455 .take()
456 .ok_or(WebsocketConnectionError::SocketClosed)?;
457 let (outgoing_messages_receiver, sub_ops_receiver, subscriptions) = handle.await?;
458 self.outgoing_messages.1.replace(outgoing_messages_receiver);
459 self.sub_ops.1.replace(sub_ops_receiver);
460 self.subscriptions.replace(subscriptions);
461 Ok(())
462 }
463}
464
465#[derive(Debug)]
466struct SubscriptionRef {
467 payload: Value,
468 sender: mpsc::Sender<graphql_client::Response<Value>>,
469}
470
471#[derive(Debug, Clone)]
472pub(super) struct SocketClient {
473 join_ref: Uuid,
474 outgoing_messages: mpsc::Sender<Message>,
475 incoming_messages: broadcast::Sender<ReceivePhoenixMessage<Value>>,
476 sub_ops: mpsc::UnboundedSender<SubOp>,
477 request_timeout: Duration,
478 cancellation_token: CancellationToken,
479}
480
481impl SocketClient {
482 pub async fn send_message<T>(
483 &self,
484 topic: String,
485 event: String,
486 payload: T,
487 ) -> Result<Uuid, WebsocketConnectionError>
488 where
489 T: Serialize,
490 {
491 let msg_ref = Uuid::new_v4();
492 let msg = serde_json::to_string(&SendPhoenixMessage {
493 join_ref: self.join_ref,
494 msg_ref,
495 topic,
496 event,
497 payload,
498 })?;
499 self.outgoing_messages.send(msg.into()).await?;
500 Ok(msg_ref)
501 }
502
503 pub async fn request<T, U>(
504 &self,
505 topic: String,
506 event: String,
507 payload: T,
508 ) -> Result<PhxReply<U>, WebsocketConnectionError>
509 where
510 T: Serialize,
511 U: DeserializeOwned,
512 {
513 let msg_ref = self.send_message(topic, event, payload).await?;
514 timeout(
515 self.request_timeout,
516 self.filter_messages::<PhxReply<U>, _>(move |msg| msg.msg_ref == Some(msg_ref))
517 .take(1)
518 .next(),
519 )
520 .await?
521 .ok_or(WebsocketConnectionError::SocketClosed)
522 }
523
524 pub async fn subscribe<T, U>(
525 &self,
526 payload: T,
527 ) -> Result<Subscription<U>, WebsocketConnectionError>
528 where
529 T: Serialize,
530 U: DeserializeOwned,
531 {
532 let subscription: PhxReply<DocumentSubscribeResponse> = self
533 .request("__absinthe__:control".into(), "doc".into(), &payload)
534 .await?;
535 let payload = serde_json::to_value(&payload)?;
536
537 let (sender, receiver) = mpsc::channel(10);
538
539 let sub_id = subscription.response.subscription_id;
540 self.sub_ops
541 .send(SubOp::AddSubscription(
542 sub_id.clone(),
543 SubscriptionRef { payload, sender },
544 ))
545 .map_err(anyhow::Error::from)?;
546
547 let this = self.clone();
548 Ok(Subscription::wrap(
549 ReceiverStream::new(receiver).filter_map(|res| serde_json::from_value(res.data?).ok()),
550 Some(move || {
551 if let Err(err) = this.sub_ops.send(SubOp::RemoveSubscription(sub_id)) {
552 tracing::error!(?err, "failed to notify unsubscribe");
553 }
554 }),
555 ))
556 }
557
558 pub fn filter_messages<T, F>(&self, mut predicate: F) -> impl Stream<Item = T>
559 where
560 T: DeserializeOwned,
561 F: FnMut(&ReceivePhoenixMessage<Value>) -> bool,
562 {
563 BroadcastStream::new(self.incoming_messages.subscribe()).filter_map(move |msg| match msg {
564 Ok(msg) => {
565 if predicate(&msg) {
566 serde_json::from_value::<T>(msg.payload).ok()
567 } else {
568 None
569 }
570 }
571 Err(_) => None,
572 })
573 }
574
575 pub fn close(self) {
576 self.cancellation_token.cancel();
577 }
578}