glimesh/conn/ws/
socket.rs

1use super::config::Config;
2use crate::{
3    entities::ws::{
4        DocumentSubscribeResponse, Empty, EventSubscription, PhxReply, ReceivePhoenixMessage,
5        SendPhoenixMessage,
6    },
7    Auth, Subscription, WebsocketConnectionError,
8};
9use async_tungstenite::{tokio::connect_async, tungstenite::Message};
10use backoff::{backoff::Backoff, ExponentialBackoff};
11use futures::{future::BoxFuture, FutureExt, SinkExt, Stream};
12use serde::{de::DeserializeOwned, Serialize};
13use serde_json::{json, Value};
14use std::{collections::HashMap, fmt::Debug, time::Duration};
15use tokio::{
16    select,
17    sync::{broadcast, mpsc},
18    task,
19    time::{sleep, timeout},
20};
21use tokio_stream::{
22    wrappers::{BroadcastStream, ReceiverStream},
23    StreamExt,
24};
25use tokio_util::sync::CancellationToken;
26use tracing::Instrument;
27use uuid::Uuid;
28
29#[derive(Debug)]
30enum SubOp {
31    AddSubscription(String, SubscriptionRef),
32    RemoveSubscription(String),
33}
34
35type CloseState = (
36    mpsc::Receiver<Message>,
37    mpsc::UnboundedReceiver<SubOp>,
38    HashMap<String, SubscriptionRef>,
39);
40
41pub(super) struct Socket {
42    auth: Auth,
43    config: Config,
44    join_ref: Uuid,
45    outgoing_messages: (mpsc::Sender<Message>, Option<mpsc::Receiver<Message>>),
46    incoming_messages: (
47        broadcast::Sender<ReceivePhoenixMessage<Value>>,
48        broadcast::Receiver<ReceivePhoenixMessage<Value>>,
49    ),
50    subscriptions: Option<HashMap<String, SubscriptionRef>>,
51    sub_ops: (
52        mpsc::UnboundedSender<SubOp>,
53        Option<mpsc::UnboundedReceiver<SubOp>>,
54    ),
55    cancellation_token: CancellationToken,
56    handle: Option<BoxFuture<'static, Result<CloseState, WebsocketConnectionError>>>,
57}
58
59impl Debug for Socket {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("Socket")
62            .field("auth", &self.auth)
63            .field("config", &self.config)
64            .field("join_ref", &self.join_ref)
65            .finish()
66    }
67}
68
69impl Socket {
70    pub fn new(auth: Auth, config: Config) -> Self {
71        let (outgoing_messages_sender, outgoing_messages_receiver) =
72            mpsc::channel(config.outgoing_capacity);
73        let incoming_messages = broadcast::channel(config.incoming_capacity);
74        let (sub_ops_sender, sub_ops_receiver) = mpsc::unbounded_channel();
75
76        Self {
77            auth,
78            config,
79            join_ref: Uuid::new_v4(),
80            outgoing_messages: (outgoing_messages_sender, Some(outgoing_messages_receiver)),
81            incoming_messages,
82            subscriptions: Some(Default::default()),
83            sub_ops: (sub_ops_sender, Some(sub_ops_receiver)),
84            cancellation_token: CancellationToken::new(),
85            handle: None,
86        }
87    }
88
89    pub fn client(&self) -> SocketClient {
90        SocketClient {
91            join_ref: self.join_ref,
92            outgoing_messages: self.outgoing_messages.0.clone(),
93            incoming_messages: self.incoming_messages.0.clone(),
94            sub_ops: self.sub_ops.0.clone(),
95            request_timeout: self.config.request_timeout,
96            cancellation_token: self.cancellation_token.clone(),
97        }
98    }
99
100    pub async fn connect(&mut self) -> Result<(), WebsocketConnectionError> {
101        let mut query = vec![("vsn", self.config.version.clone())];
102        match &self.auth {
103            Auth::ClientId(client_id) => query.push(("client_id", client_id.clone())),
104            Auth::AccessToken(token) => query.push(("token", token.clone())),
105            Auth::RefreshableAccessToken(token) => {
106                let access_token = token.access_token().await?;
107                query.push(("token", access_token.access_token));
108            }
109            Auth::ClientCredentials(client_credentials) => {
110                let access_token = client_credentials.access_token().await?;
111                query.push(("token", access_token.access_token));
112            }
113        }
114
115        let query_str = serde_urlencoded::to_string(query.as_slice())?;
116        let connection_url = format!("{}?{}", self.config.api_url, query_str);
117
118        let (ws_stream, _) = connect_async(&connection_url).await?;
119        let (mut ws_tx, mut ws_rx) = futures::StreamExt::split(ws_stream);
120
121        let cancellation_token = self.cancellation_token.child_token();
122
123        let outgoing_messages_handle = {
124            let mut outgoing_messages_receiver = self
125                .outgoing_messages
126                .1
127                .take()
128                .ok_or(WebsocketConnectionError::AlreadyConnected)?;
129            let cancellation_token = cancellation_token.clone();
130            task::spawn(async move {
131                loop {
132                    select! {
133                        _ = cancellation_token.cancelled() => {
134                            tracing::trace!("received cancellation signal");
135                            break;
136                        }
137                        msg = outgoing_messages_receiver.recv() => {
138                            match msg {
139                                Some(msg) => {
140                                    tracing::trace!(?msg, "sending message");
141                                    if let Err(err) = ws_tx.send(msg).await {
142                                        tracing::error!(?err, "failed to send message on the socket");
143                                        cancellation_token.cancel();
144                                        break;
145                                    }
146                                }
147                                None => {
148                                    tracing::trace!("all senders were dropped");
149                                    cancellation_token.cancel();
150                                    break;
151                                }
152                            }
153                        }
154                    }
155                }
156
157                outgoing_messages_receiver
158            })
159            .instrument(tracing::trace_span!("outgoing_messages"))
160        };
161
162        let incoming_messages_handle = {
163            let cancellation_token = cancellation_token.clone();
164            let incoming_messages_sender = self.incoming_messages.0.clone();
165            task::spawn(async move {
166                loop {
167                    select! {
168                        _ = cancellation_token.cancelled() => {
169                            tracing::trace!("received cancellation signal");
170                            break;
171                        }
172                        msg = ws_rx.next() => {
173                            match msg {
174                                Some(Ok(Message::Text(text))) => {
175                                    match serde_json::from_str::<ReceivePhoenixMessage<Value>>(&text) {
176                                        Ok(msg) => {
177                                            if msg.event == "phx_error" {
178                                                tracing::error!(?msg.payload, "error on socket");
179                                                cancellation_token.cancel();
180                                                break;
181                                            }
182
183                                            tracing::trace!(?msg, "incoming message");
184                                            if let Err(err) = incoming_messages_sender.send(msg) {
185                                                tracing::error!(?text, ?err, "failed to broadcast incoming message");
186                                            }
187                                        }
188                                        Err(err) => {
189                                            tracing::error!(?text, ?err, "failed to deserialize glimesh message");
190                                        }
191                                    }
192                                }
193                                Some(Ok(Message::Close(reason))) => {
194                                    tracing::error!(?reason, "socket closed");
195                                    cancellation_token.cancel();
196                                    break;
197                                }
198                                Some(Ok(frame)) => {
199                                    tracing::error!(?frame, "unexpected frame type");
200                                    cancellation_token.cancel();
201                                    break;
202                                }
203                                Some(Err(err)) => {
204                                    tracing::error!(?err, "socket error");
205                                    cancellation_token.cancel();
206                                    break;
207                                }
208                                None => {
209                                    // The socket must have errored in the previous
210                                    // iteration so we should never really get here
211                                    tracing::error!("no more socket messages");
212                                    cancellation_token.cancel();
213                                    break;
214                                }
215                            }
216                        }
217                    }
218                }
219            })
220            .instrument(tracing::trace_span!("incoming_messages"))
221        };
222
223        let socket_client = self.client();
224        if let Err(err) = socket_client
225            .request::<_, Empty>("__absinthe__:control".into(), "phx_join".into(), Empty {})
226            .await
227        {
228            tracing::error!(?err, "join request failed");
229            cancellation_token.cancel();
230            return Err(err);
231        }
232
233        let pinger_handle = {
234            let ping_interval = Duration::from_secs(30);
235            let cancellation_token = cancellation_token.clone();
236            task::spawn(async move {
237                loop {
238                    select! {
239                        _ = cancellation_token.cancelled() => {
240                            tracing::trace!("received cancellation signal");
241                            break;
242                        }
243                        _ = sleep(ping_interval) => {
244                            if let Err(err) = socket_client.request::<_, Empty>(
245                                "phoenix".into(),
246                                "heartbeat".into(),
247                                Empty {},
248                            )
249                            .await {
250                                tracing::error!(?err, "failed to send ping");
251                                cancellation_token.cancel();
252                                break;
253                            }
254                        }
255                    };
256                }
257            })
258            .instrument(tracing::trace_span!("pinger"))
259        };
260
261        let subscriptions_handle = {
262            let socket_client = self.client();
263            let mut sub_ops_receiver = self
264                .sub_ops
265                .1
266                .take()
267                .ok_or(WebsocketConnectionError::AlreadyConnected)?;
268            let mut subscriptions = self
269                .subscriptions
270                .take()
271                .ok_or(WebsocketConnectionError::AlreadyConnected)?;
272            task::spawn(async move {
273                let sub_ids = subscriptions.keys().cloned().collect::<Vec<_>>();
274                for old_sub_id in sub_ids {
275                    // break out early if we've been told to cancel
276                    if cancellation_token.is_cancelled() {
277                        break;
278                    }
279
280                    let sub = subscriptions.remove(&old_sub_id).unwrap();
281                    let op = || async {
282                        let res = socket_client
283                            .request::<_, DocumentSubscribeResponse>(
284                                "__absinthe__:control".into(),
285                                "doc".into(),
286                                &sub.payload,
287                            )
288                            .await;
289
290                        match res {
291                            Ok(subscription) => {
292                                Ok(subscription.response.subscription_id)
293                            }
294                            Err(err) => {
295                                tracing::debug!(?err, ?sub, "failed to resubscribe");
296
297                                if cancellation_token.is_cancelled() {
298                                    Err(backoff::Error::permanent(err))
299                                } else {
300                                    Err(backoff::Error::transient(err))
301                                }
302                            }
303                        }
304                    };
305
306                    match backoff::future::retry(ExponentialBackoff::default(), op).await {
307                        Ok(sub_id) => {
308                            tracing::debug!(?sub, "resubscribed");
309                            subscriptions.insert(sub_id, sub);
310                        }
311                        Err(err) => {
312                            tracing::error!(?err, "fatal error trying to resubscribe to subscriptions (did the socket die?)");
313                            // add the old sub back so we can retry on reconnect
314                            subscriptions.insert(old_sub_id, sub);
315                            // break of the loop, the socket is dead, lets reconnect
316                            break;
317                        }
318                    }
319                }
320
321                if !cancellation_token.is_cancelled() {
322                    let mut messages = socket_client.filter_messages::<EventSubscription, _>(|msg| {
323                        msg.event == "subscription:data" && msg.topic.starts_with("__absinthe__:doc")
324                    });
325
326                    loop {
327                        select! {
328                            _ = cancellation_token.cancelled() => {
329                                tracing::trace!("received cancellation signal");
330                                break;
331                            }
332                            sub = sub_ops_receiver.recv() => {
333                                match sub {
334                                    Some(SubOp::AddSubscription(sub_id, sub)) => {
335                                        subscriptions.insert(sub_id, sub);
336                                    }
337                                    Some(SubOp::RemoveSubscription(sub_id)) => {
338                                        subscriptions.remove(&sub_id);
339                                        let socket_client = socket_client.clone();
340                                        task::spawn(async move {
341                                            let payload = json!({ "subscriptionId": sub_id });
342                                            if let Err(err) = socket_client.send_message(
343                                                "__absinthe__:control".into(),
344                                                "unsubscribe".into(),
345                                                payload
346                                            ).await {
347                                                tracing::error!(?err, "failed to send unsubscribe request");
348                                            }
349                                        });
350                                    }
351                                    None => {
352                                        tracing::trace!("all senders were dropped");
353                                        cancellation_token.cancel();
354                                        break;
355                                    }
356                                }
357                            }
358                            msg = messages.next() => {
359                                match msg {
360                                    Some(EventSubscription{ result, subscription_id }) => {
361                                        if let Some(subscriber) = subscriptions.get(&subscription_id) {
362                                            match serde_json::from_value::<graphql_client::Response<Value>>(result) {
363                                                Ok(msg) => {
364                                                    if let Err(err) = subscriber.sender.send(msg).await {
365                                                        tracing::error!(?err, "failed to notify subscriber of event");
366                                                    }
367                                                }
368                                                Err(err) => {
369                                                    tracing::error!(?err, "invalid subscription message received");
370                                                }
371                                            }
372                                        }
373                                    }
374                                    None => {
375                                        tracing::trace!("all senders were dropped");
376                                        cancellation_token.cancel();
377                                        break;
378                                    }
379                                }
380                            }
381                        }
382                    }
383                }
384
385                (sub_ops_receiver, subscriptions)
386            })
387            .instrument(tracing::trace_span!("subscriptions"))
388        };
389
390        self.handle.replace(
391            async move {
392                incoming_messages_handle
393                    .await
394                    .map_err(anyhow::Error::from)?;
395                pinger_handle.await.map_err(anyhow::Error::from)?;
396                let outgoing_messages_receiver = outgoing_messages_handle
397                    .await
398                    .map_err(anyhow::Error::from)?;
399                let (sub_ops_receiver, subscriptions) =
400                    subscriptions_handle.await.map_err(anyhow::Error::from)?;
401                Ok::<_, WebsocketConnectionError>((
402                    outgoing_messages_receiver,
403                    sub_ops_receiver,
404                    subscriptions,
405                ))
406            }
407            .boxed(),
408        );
409
410        tracing::debug!("connected to socket");
411
412        Ok(())
413    }
414
415    pub fn stay_conected(mut self) {
416        task::spawn(async move {
417            loop {
418                if let Err(err) = self.wait().await {
419                    tracing::error!(?err, "irrecoverable connecton error");
420                    // TODO: some way of bubbling this up to the consumer
421                    break;
422                }
423
424                if self.cancellation_token.is_cancelled() {
425                    break;
426                }
427
428                let mut backoff = ExponentialBackoff::default();
429                while let Err(err) = self.connect().await {
430                    match backoff.next_backoff() {
431                        Some(backoff_time) => {
432                            tracing::error!(
433                                ?err,
434                                "failed to reconnect, retrying in {:?}",
435                                backoff_time
436                            );
437                            sleep(backoff_time).await;
438                        }
439                        None => {
440                            tracing::error!(?err, "failed to reconnect, after many attempts");
441                            // TODO: some way of bubbling this up to the consumer
442                            return;
443                        }
444                    }
445                }
446
447                tracing::info!("successfully reconnected")
448            }
449        });
450    }
451
452    async fn wait(&mut self) -> Result<(), WebsocketConnectionError> {
453        let handle = self
454            .handle
455            .take()
456            .ok_or(WebsocketConnectionError::SocketClosed)?;
457        let (outgoing_messages_receiver, sub_ops_receiver, subscriptions) = handle.await?;
458        self.outgoing_messages.1.replace(outgoing_messages_receiver);
459        self.sub_ops.1.replace(sub_ops_receiver);
460        self.subscriptions.replace(subscriptions);
461        Ok(())
462    }
463}
464
465#[derive(Debug)]
466struct SubscriptionRef {
467    payload: Value,
468    sender: mpsc::Sender<graphql_client::Response<Value>>,
469}
470
471#[derive(Debug, Clone)]
472pub(super) struct SocketClient {
473    join_ref: Uuid,
474    outgoing_messages: mpsc::Sender<Message>,
475    incoming_messages: broadcast::Sender<ReceivePhoenixMessage<Value>>,
476    sub_ops: mpsc::UnboundedSender<SubOp>,
477    request_timeout: Duration,
478    cancellation_token: CancellationToken,
479}
480
481impl SocketClient {
482    pub async fn send_message<T>(
483        &self,
484        topic: String,
485        event: String,
486        payload: T,
487    ) -> Result<Uuid, WebsocketConnectionError>
488    where
489        T: Serialize,
490    {
491        let msg_ref = Uuid::new_v4();
492        let msg = serde_json::to_string(&SendPhoenixMessage {
493            join_ref: self.join_ref,
494            msg_ref,
495            topic,
496            event,
497            payload,
498        })?;
499        self.outgoing_messages.send(msg.into()).await?;
500        Ok(msg_ref)
501    }
502
503    pub async fn request<T, U>(
504        &self,
505        topic: String,
506        event: String,
507        payload: T,
508    ) -> Result<PhxReply<U>, WebsocketConnectionError>
509    where
510        T: Serialize,
511        U: DeserializeOwned,
512    {
513        let msg_ref = self.send_message(topic, event, payload).await?;
514        timeout(
515            self.request_timeout,
516            self.filter_messages::<PhxReply<U>, _>(move |msg| msg.msg_ref == Some(msg_ref))
517                .take(1)
518                .next(),
519        )
520        .await?
521        .ok_or(WebsocketConnectionError::SocketClosed)
522    }
523
524    pub async fn subscribe<T, U>(
525        &self,
526        payload: T,
527    ) -> Result<Subscription<U>, WebsocketConnectionError>
528    where
529        T: Serialize,
530        U: DeserializeOwned,
531    {
532        let subscription: PhxReply<DocumentSubscribeResponse> = self
533            .request("__absinthe__:control".into(), "doc".into(), &payload)
534            .await?;
535        let payload = serde_json::to_value(&payload)?;
536
537        let (sender, receiver) = mpsc::channel(10);
538
539        let sub_id = subscription.response.subscription_id;
540        self.sub_ops
541            .send(SubOp::AddSubscription(
542                sub_id.clone(),
543                SubscriptionRef { payload, sender },
544            ))
545            .map_err(anyhow::Error::from)?;
546
547        let this = self.clone();
548        Ok(Subscription::wrap(
549            ReceiverStream::new(receiver).filter_map(|res| serde_json::from_value(res.data?).ok()),
550            Some(move || {
551                if let Err(err) = this.sub_ops.send(SubOp::RemoveSubscription(sub_id)) {
552                    tracing::error!(?err, "failed to notify unsubscribe");
553                }
554            }),
555        ))
556    }
557
558    pub fn filter_messages<T, F>(&self, mut predicate: F) -> impl Stream<Item = T>
559    where
560        T: DeserializeOwned,
561        F: FnMut(&ReceivePhoenixMessage<Value>) -> bool,
562    {
563        BroadcastStream::new(self.incoming_messages.subscribe()).filter_map(move |msg| match msg {
564            Ok(msg) => {
565                if predicate(&msg) {
566                    serde_json::from_value::<T>(msg.payload).ok()
567                } else {
568                    None
569                }
570            }
571            Err(_) => None,
572        })
573    }
574
575    pub fn close(self) {
576        self.cancellation_token.cancel();
577    }
578}