apollo_router/plugins/subscription/
notification.rs

1//! Internal pub/sub facility for subscription
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::hash::Hash;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::sync::Weak;
8use std::task::Context;
9use std::task::Poll;
10use std::time::Duration;
11use std::time::Instant;
12
13use futures::Sink;
14use futures::Stream;
15use futures::StreamExt;
16use pin_project_lite::pin_project;
17use thiserror::Error;
18use tokio::sync::broadcast;
19use tokio::sync::mpsc;
20use tokio::sync::mpsc::error::SendError;
21use tokio::sync::mpsc::error::TrySendError;
22use tokio::sync::oneshot;
23use tokio::sync::oneshot::error::RecvError;
24use tokio_stream::wrappers::BroadcastStream;
25use tokio_stream::wrappers::IntervalStream;
26use tokio_stream::wrappers::ReceiverStream;
27use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
28
29use crate::Configuration;
30use crate::graphql;
31use crate::metrics::FutureMetricsExt;
32use crate::metrics::UpDownCounterGuard;
33use crate::spec::Schema;
34
/// Capacity of the `mpsc` command channel between [`Notify`] and the pub/sub task.
const NOTIFY_CHANNEL_SIZE: usize = 1024;
/// Capacity of a topic's broadcast channel when no `queue_size` is configured.
const DEFAULT_MSG_CHANNEL_SIZE: usize = 128;
37
38#[derive(Error, Debug)]
39pub(crate) enum NotifyError<K, V> {
40    #[error("cannot receive data from pubsub")]
41    RecvError(#[from] RecvError),
42    #[error("cannot send data to pubsub")]
43    SendError(#[from] SendError<V>),
44    #[error("cannot send data to pubsub")]
45    NotificationSendError(#[from] SendError<Notification<K, V>>),
46    #[error("cannot send data to pubsub")]
47    NotificationTrySendError(#[from] TrySendError<Notification<K, V>>),
48    #[error("cannot send data to response stream")]
49    BroadcastSendError(#[from] broadcast::error::SendError<V>),
50    #[error("this topic doesn't exist")]
51    UnknownTopic,
52}
53
54type ResponseSender<V> =
55    oneshot::Sender<Option<(broadcast::Sender<Option<V>>, broadcast::Receiver<Option<V>>)>>;
56
57pub(crate) struct CreatedTopicPayload<V> {
58    msg_sender: broadcast::Sender<Option<V>>,
59    msg_receiver: broadcast::Receiver<Option<V>>,
60    closing_signal: broadcast::Receiver<()>,
61    created: bool,
62}
63
64type ResponseSenderWithCreated<V> = oneshot::Sender<CreatedTopicPayload<V>>;
65
66pub(crate) enum Notification<K, V> {
67    CreateOrSubscribe {
68        topic: K,
69        // Sender connected to the original source stream
70        msg_sender: broadcast::Sender<Option<V>>,
71        // To know if it has been created or re-used
72        response_sender: ResponseSenderWithCreated<V>,
73        heartbeat_enabled: bool,
74        // Useful for the metric we create
75        operation_name: Option<String>,
76    },
77    Subscribe {
78        topic: K,
79        // Oneshot channel to fetch the receiver
80        response_sender: ResponseSender<V>,
81    },
82    SubscribeIfExist {
83        topic: K,
84        // Oneshot channel to fetch the receiver
85        response_sender: ResponseSender<V>,
86    },
87    Unsubscribe {
88        topic: K,
89    },
90    ForceDelete {
91        topic: K,
92    },
93    Exist {
94        topic: K,
95        response_sender: oneshot::Sender<bool>,
96    },
97    InvalidIds {
98        topics: Vec<K>,
99        response_sender: oneshot::Sender<(Vec<K>, Vec<K>)>,
100    },
101    UpdateHeartbeat {
102        new_ttl: Option<Duration>,
103    },
104    #[cfg(test)]
105    TryDelete {
106        topic: K,
107    },
108    #[cfg(test)]
109    Broadcast {
110        data: V,
111    },
112    #[cfg(test)]
113    Debug {
114        // Returns the number of subscriptions and subscribers
115        response_sender: oneshot::Sender<usize>,
116    },
117}
118
119impl<K, V> Debug for Notification<K, V> {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        match self {
122            Self::CreateOrSubscribe { .. } => f.debug_struct("CreateOrSubscribe").finish(),
123            Self::Subscribe { .. } => f.debug_struct("Subscribe").finish(),
124            Self::SubscribeIfExist { .. } => f.debug_struct("SubscribeIfExist").finish(),
125            Self::Unsubscribe { .. } => f.debug_struct("Unsubscribe").finish(),
126            Self::ForceDelete { .. } => f.debug_struct("ForceDelete").finish(),
127            Self::Exist { .. } => f.debug_struct("Exist").finish(),
128            Self::InvalidIds { .. } => f.debug_struct("InvalidIds").finish(),
129            Self::UpdateHeartbeat { .. } => f.debug_struct("UpdateHeartbeat").finish(),
130            #[cfg(test)]
131            Self::TryDelete { .. } => f.debug_struct("TryDelete").finish(),
132            #[cfg(test)]
133            Self::Broadcast { .. } => f.debug_struct("Broadcast").finish(),
134            #[cfg(test)]
135            Self::Debug { .. } => f.debug_struct("Debug").finish(),
136        }
137    }
138}
139
140/// In memory pub/sub implementation
141#[derive(Clone)]
142pub struct Notify<K, V> {
143    sender: mpsc::Sender<Notification<K, V>>,
144    /// Size (number of events) of the channel to receive message
145    pub(crate) queue_size: Option<usize>,
146    router_broadcasts: Arc<RouterBroadcasts>,
147}
148
149#[buildstructor::buildstructor]
150impl<K, V> Notify<K, V>
151where
152    K: Send + Hash + Eq + Clone + 'static,
153    V: Send + Sync + Clone + 'static,
154{
155    #[builder]
156    pub(crate) fn new(
157        ttl: Option<Duration>,
158        heartbeat_error_message: Option<V>,
159        queue_size: Option<usize>,
160    ) -> Notify<K, V> {
161        let (sender, receiver) = mpsc::channel(NOTIFY_CHANNEL_SIZE);
162        let receiver_stream: ReceiverStream<Notification<K, V>> = ReceiverStream::new(receiver);
163        tokio::task::spawn(
164            task(receiver_stream, ttl, heartbeat_error_message).with_current_meter_provider(),
165        );
166        Notify {
167            sender,
168            queue_size,
169            router_broadcasts: Arc::new(RouterBroadcasts::new()),
170        }
171    }
172
173    #[doc(hidden)]
174    /// NOOP notifier for tests
175    pub fn for_tests() -> Self {
176        let (sender, _receiver) = mpsc::channel(NOTIFY_CHANNEL_SIZE);
177        Notify {
178            sender,
179            queue_size: None,
180            router_broadcasts: Arc::new(RouterBroadcasts::new()),
181        }
182    }
183}
184
185impl<K, V> Notify<K, V> {
186    /// Broadcast a new configuration
187    pub(crate) fn broadcast_configuration(&self, configuration: Weak<Configuration>) {
188        self.router_broadcasts.configuration.0.send(configuration).expect("cannot send the configuration update to the static channel. Should not happen because the receiver will always live in this struct; qed");
189    }
190    /// Receive the new configuration everytime we have a new router configuration
191    pub(crate) fn subscribe_configuration(
192        &self,
193    ) -> impl Stream<Item = Weak<Configuration>> + use<K, V> {
194        self.router_broadcasts.subscribe_configuration()
195    }
196    /// Receive the new schema everytime we have a new schema
197    pub(crate) fn broadcast_schema(&self, schema: Arc<Schema>) {
198        self.router_broadcasts.schema.0.send(schema).expect("cannot send the schema update to the static channel. Should not happen because the receiver will always live in this struct; qed");
199    }
200    /// Receive the new schema everytime we have a new schema
201    pub(crate) fn subscribe_schema(&self) -> impl Stream<Item = Arc<Schema>> + use<K, V> {
202        self.router_broadcasts.subscribe_schema()
203    }
204}
205
206impl<K, V> Notify<K, V>
207where
208    K: Send + Hash + Eq + Clone + 'static,
209    V: Send + Clone + 'static,
210{
211    pub(crate) async fn set_ttl(&self, new_ttl: Option<Duration>) -> Result<(), NotifyError<K, V>> {
212        self.sender
213            .send(Notification::UpdateHeartbeat { new_ttl })
214            .await?;
215
216        Ok(())
217    }
218
219    /// Creates or subscribes to a topic, returning a handle and subscription state.
220    ///
221    /// The `Ok()` branch of the `Result` is a tuple, where:
222    ///     - .0: a `Handle` on the subscription event listener,
223    ///     - .1: a boolean, where
224    ///              - `true`: call to this fn `created` this subscription, and
225    ///              - `false`: call to this fn was for a deduplicated subscription
226    ///                         i.e. subscription already exists,
227    ///     - .2: a closing signal in a form of `broadcast::Receiver` that gets
228    ///            triggered once the subscription is closed.
229    ///
230    /// # Closing Signal Usage
231    ///
232    /// The closing signal's usage depends on how subscriptions are managed:
233    ///
234    /// ## Callback Mode (HTTP-based subscriptions)
235    /// - The closing signal is typically **unused** as there are no long-running
236    ///   forwarding tasks to clean up
237    /// - Subscriptions are managed via HTTP callbacks to a public URL
238    /// - Subscription lifecycle is controlled through HTTP responses (404 closes the subscription)
239    /// - Always called with `heartbeat_enabled = true` to enable TTL-based timeout checking
240    ///
241    /// ## Passthrough Mode (WebSocket-based subscriptions)  
242    /// - The closing signal **must be monitored** by the forwarding task using `tokio::select!`
243    /// - Maintains persistent WebSocket connections to subgraphs
244    /// - Needed for proper cleanup when subscriptions are terminated, especially important
245    ///   for deduplication (multiple clients may share one subgraph connection)
246    /// - Always called with `heartbeat_enabled = false` as WebSockets have their own
247    ///   connection management
248    ///
249    /// # Parameters
250    /// - `topic`: The subscription topic identifier
251    /// - `heartbeat_enabled`: Controls TTL-based timeout checking at the notification layer:
252    ///   - `true`: Enables TTL checking. For callback mode, subscriptions will timeout if
253    ///     no heartbeat is received within the TTL period. The actual heartbeat interval
254    ///     is configured separately and sent to subgraphs in the subscription extension.
255    ///     When subgraphs send heartbeat messages, they're processed via `invalid_ids()`
256    ///     which calls `touch()` to update the subscription's `updated_at` timestamp.
257    ///   - `false`: Disables TTL checking (used by passthrough/WebSocket mode)
258    /// - `operation_name`: Optional GraphQL operation name for metrics
259    ///
260    /// # Heartbeat Processing for Callback Mode
261    ///
262    /// When callback mode is configured with a heartbeat interval:
263    /// 1. The interval is converted to milliseconds and sent to the subgraph as
264    ///    `heartbeat_interval_ms` in the subscription extension
265    /// 2. Subgraphs send periodic heartbeat callbacks with subscription IDs
266    /// 3. The heartbeat handler validates IDs and calls `notify.invalid_ids()`
267    /// 4. This updates each valid subscription's timestamp via `touch()`
268    /// 5. The TTL checker uses these timestamps to determine if subscriptions are alive
269    ///    and closes those that haven't been touched within the TTL period
270    pub(crate) async fn create_or_subscribe(
271        &mut self,
272        topic: K,
273        heartbeat_enabled: bool,
274        operation_name: Option<String>,
275    ) -> Result<(Handle<K, V>, bool, broadcast::Receiver<()>), NotifyError<K, V>> {
276        let (sender, _receiver) =
277            broadcast::channel(self.queue_size.unwrap_or(DEFAULT_MSG_CHANNEL_SIZE));
278
279        let (tx, rx) = oneshot::channel();
280        self.sender
281            .send(Notification::CreateOrSubscribe {
282                topic: topic.clone(),
283                msg_sender: sender,
284                response_sender: tx,
285                heartbeat_enabled,
286                operation_name,
287            })
288            .await?;
289
290        let CreatedTopicPayload {
291            msg_sender,
292            msg_receiver,
293            closing_signal,
294            created,
295        } = rx.await?;
296        let handle = Handle::new(
297            topic,
298            self.sender.clone(),
299            msg_sender,
300            BroadcastStream::from(msg_receiver),
301        );
302
303        Ok((handle, created, closing_signal))
304    }
305
306    pub(crate) async fn subscribe(&mut self, topic: K) -> Result<Handle<K, V>, NotifyError<K, V>> {
307        let (sender, receiver) = oneshot::channel();
308
309        self.sender
310            .send(Notification::Subscribe {
311                topic: topic.clone(),
312                response_sender: sender,
313            })
314            .await?;
315
316        let Some((msg_sender, msg_receiver)) = receiver.await? else {
317            return Err(NotifyError::UnknownTopic);
318        };
319        let handle = Handle::new(
320            topic,
321            self.sender.clone(),
322            msg_sender,
323            BroadcastStream::from(msg_receiver),
324        );
325
326        Ok(handle)
327    }
328
329    pub(crate) async fn subscribe_if_exist(
330        &mut self,
331        topic: K,
332    ) -> Result<Option<Handle<K, V>>, NotifyError<K, V>> {
333        let (sender, receiver) = oneshot::channel();
334
335        self.sender
336            .send(Notification::SubscribeIfExist {
337                topic: topic.clone(),
338                response_sender: sender,
339            })
340            .await?;
341
342        let Some((msg_sender, msg_receiver)) = receiver.await? else {
343            return Ok(None);
344        };
345        let handle = Handle::new(
346            topic,
347            self.sender.clone(),
348            msg_sender,
349            BroadcastStream::from(msg_receiver),
350        );
351
352        Ok(handle.into())
353    }
354
355    pub(crate) async fn exist(&mut self, topic: K) -> Result<bool, NotifyError<K, V>> {
356        // Channel to check if the topic still exists or not
357        let (response_tx, response_rx) = oneshot::channel();
358
359        self.sender
360            .send(Notification::Exist {
361                topic,
362                response_sender: response_tx,
363            })
364            .await?;
365
366        let resp = response_rx.await?;
367
368        Ok(resp)
369    }
370
371    pub(crate) async fn invalid_ids(
372        &mut self,
373        topics: Vec<K>,
374    ) -> Result<(Vec<K>, Vec<K>), NotifyError<K, V>> {
375        // Channel to check if the topic still exists or not
376        let (response_tx, response_rx) = oneshot::channel();
377
378        self.sender
379            .send(Notification::InvalidIds {
380                topics,
381                response_sender: response_tx,
382            })
383            .await?;
384
385        let resp = response_rx.await?;
386
387        Ok(resp)
388    }
389
390    /// Delete the topic even if several subscribers are still listening
391    pub(crate) async fn force_delete(&mut self, topic: K) -> Result<(), NotifyError<K, V>> {
392        // if disconnected, we don't care (the task was stopped)
393        self.sender
394            .send(Notification::ForceDelete { topic })
395            .await
396            .map_err(std::convert::Into::into)
397    }
398
399    /// Delete the topic if and only if one or zero subscriber is still listening
400    /// This function is not async to allow it to be used in a Drop impl
401    #[cfg(test)]
402    pub(crate) fn try_delete(&mut self, topic: K) -> Result<(), NotifyError<K, V>> {
403        // if disconnected, we don't care (the task was stopped)
404        self.sender
405            .try_send(Notification::TryDelete { topic })
406            .map_err(|try_send_error| try_send_error.into())
407    }
408
409    #[cfg(test)]
410    pub(crate) async fn broadcast(&mut self, data: V) -> Result<(), NotifyError<K, V>> {
411        self.sender
412            .send(Notification::Broadcast { data })
413            .await
414            .map_err(std::convert::Into::into)
415    }
416
417    #[cfg(test)]
418    pub(crate) async fn debug(&mut self) -> Result<usize, NotifyError<K, V>> {
419        let (response_tx, response_rx) = oneshot::channel();
420        self.sender
421            .send(Notification::Debug {
422                response_sender: response_tx,
423            })
424            .await?;
425
426        Ok(response_rx.await.unwrap())
427    }
428}
429
#[cfg(test)]
impl<K, V> Default for Notify<K, V>
where
    K: Send + Hash + Eq + Clone + 'static,
    V: Send + Sync + Clone + 'static,
{
    /// Test-only default: the same NOOP notifier as [`Notify::for_tests`].
    fn default() -> Self {
        Notify::for_tests()
    }
}
441
442impl<K, V> Debug for Notify<K, V> {
443    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444        f.debug_struct("Notify").finish()
445    }
446}
447
448struct HandleGuard<K, V>
449where
450    K: Clone,
451{
452    topic: Arc<K>,
453    pubsub_sender: mpsc::Sender<Notification<K, V>>,
454}
455
456impl<K, V> Clone for HandleGuard<K, V>
457where
458    K: Clone,
459{
460    fn clone(&self) -> Self {
461        Self {
462            topic: self.topic.clone(),
463            pubsub_sender: self.pubsub_sender.clone(),
464        }
465    }
466}
467
468impl<K, V> Drop for HandleGuard<K, V>
469where
470    K: Clone,
471{
472    fn drop(&mut self) {
473        let err = self.pubsub_sender.try_send(Notification::Unsubscribe {
474            topic: self.topic.as_ref().clone(),
475        });
476        if let Err(err) = err {
477            tracing::trace!("cannot unsubscribe {err:?}");
478        }
479    }
480}
481
482pin_project! {
483pub struct Handle<K, V>
484where
485    K: Clone,
486{
487    handle_guard: HandleGuard<K, V>,
488    #[pin]
489    msg_sender: broadcast::Sender<Option<V>>,
490    #[pin]
491    msg_receiver: BroadcastStream<Option<V>>,
492}
493}
494
495impl<K, V> Clone for Handle<K, V>
496where
497    K: Clone,
498    V: Clone + Send + 'static,
499{
500    fn clone(&self) -> Self {
501        Self {
502            handle_guard: self.handle_guard.clone(),
503            msg_receiver: BroadcastStream::new(self.msg_sender.subscribe()),
504            msg_sender: self.msg_sender.clone(),
505        }
506    }
507}
508
509impl<K, V> Handle<K, V>
510where
511    K: Clone,
512{
513    fn new(
514        topic: K,
515        pubsub_sender: mpsc::Sender<Notification<K, V>>,
516        msg_sender: broadcast::Sender<Option<V>>,
517        msg_receiver: BroadcastStream<Option<V>>,
518    ) -> Self {
519        Self {
520            handle_guard: HandleGuard {
521                topic: Arc::new(topic),
522                pubsub_sender,
523            },
524            msg_sender,
525            msg_receiver,
526        }
527    }
528
529    pub(crate) fn into_stream(self) -> HandleStream<K, V> {
530        HandleStream {
531            handle_guard: self.handle_guard,
532            msg_receiver: self.msg_receiver,
533        }
534    }
535
536    pub(crate) fn into_sink(self) -> HandleSink<K, V> {
537        HandleSink {
538            handle_guard: self.handle_guard,
539            msg_sender: self.msg_sender,
540        }
541    }
542
543    /// Return a sink and a stream
544    pub fn split(self) -> (HandleSink<K, V>, HandleStream<K, V>) {
545        (
546            HandleSink {
547                handle_guard: self.handle_guard.clone(),
548                msg_sender: self.msg_sender,
549            },
550            HandleStream {
551                handle_guard: self.handle_guard,
552                msg_receiver: self.msg_receiver,
553            },
554        )
555    }
556}
557
558pin_project! {
559pub struct HandleStream<K, V>
560where
561    K: Clone,
562{
563    handle_guard: HandleGuard<K, V>,
564    #[pin]
565    msg_receiver: BroadcastStream<Option<V>>,
566}
567}
568
569impl<K, V> Stream for HandleStream<K, V>
570where
571    K: Clone,
572    V: Clone + 'static + Send,
573{
574    type Item = V;
575
576    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
577        let mut this = self.as_mut().project();
578
579        match Pin::new(&mut this.msg_receiver).poll_next(cx) {
580            Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(_)))) => {
581                u64_counter!(
582                    "apollo.router.skipped.event.count",
583                    "Amount of events dropped from the internal message queue",
584                    1u64
585                );
586                self.poll_next(cx)
587            }
588            Poll::Ready(None) => Poll::Ready(None),
589            Poll::Ready(Some(Ok(Some(val)))) => Poll::Ready(Some(val)),
590            Poll::Ready(Some(Ok(None))) => Poll::Ready(None),
591            Poll::Pending => Poll::Pending,
592        }
593    }
594}
595
596pin_project! {
597pub struct HandleSink<K, V>
598where
599    K: Clone,
600{
601    handle_guard: HandleGuard<K, V>,
602    #[pin]
603    msg_sender: broadcast::Sender<Option<V>>,
604}
605}
606
607impl<K, V> HandleSink<K, V>
608where
609    K: Clone,
610    V: Clone + 'static + Send,
611{
612    /// Send data to the subscribed topic
613    pub(crate) fn send_sync(&mut self, data: V) -> Result<(), NotifyError<K, V>> {
614        self.msg_sender.send(data.into()).map_err(|err| {
615            NotifyError::BroadcastSendError(broadcast::error::SendError(err.0.unwrap()))
616        })?;
617
618        Ok(())
619    }
620}
621
622impl<K, V> Sink<V> for HandleSink<K, V>
623where
624    K: Clone,
625    V: Clone + 'static + Send,
626{
627    type Error = graphql::Error;
628
629    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
630        Poll::Ready(Ok(()))
631    }
632
633    fn start_send(self: Pin<&mut Self>, item: V) -> Result<(), Self::Error> {
634        self.msg_sender.send(Some(item)).map_err(|_err| {
635            graphql::Error::builder()
636                .message("cannot send payload through pubsub")
637                .extension_code("NOTIFICATION_HANDLE_SEND_ERROR")
638                .build()
639        })?;
640        Ok(())
641    }
642
643    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
644        Poll::Ready(Ok(()))
645    }
646
647    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
648        let topic = self.handle_guard.topic.as_ref().clone();
649        let _ = self
650            .handle_guard
651            .pubsub_sender
652            .try_send(Notification::ForceDelete { topic });
653        Poll::Ready(Ok(()))
654    }
655}
656
657impl<K, V> Handle<K, V> where K: Clone {}
658
659async fn task<K, V>(
660    mut receiver: ReceiverStream<Notification<K, V>>,
661    mut ttl: Option<Duration>,
662    heartbeat_error_message: Option<V>,
663) where
664    K: Send + Hash + Eq + Clone + 'static,
665    V: Send + Clone + 'static,
666{
667    let mut pubsub: PubSub<K, V> = PubSub::new(ttl);
668
669    let mut ttl_fut: Box<dyn Stream<Item = tokio::time::Instant> + Send + Unpin> = match ttl {
670        Some(ttl) => Box::new(IntervalStream::new(tokio::time::interval(ttl))),
671        None => Box::new(tokio_stream::pending()),
672    };
673
674    loop {
675        tokio::select! {
676            _ = ttl_fut.next() => {
677                let heartbeat_error_message = heartbeat_error_message.clone();
678                pubsub.kill_dead_topics(heartbeat_error_message).await;
679            }
680            message = receiver.next() => {
681                match message {
682                    Some(message) => {
683                        match message {
684                            Notification::Unsubscribe { topic } => pubsub.unsubscribe(topic),
685                            Notification::ForceDelete { topic } => pubsub.force_delete(topic),
686                            Notification::CreateOrSubscribe { topic,  msg_sender, response_sender, heartbeat_enabled, operation_name } => {
687                                pubsub.subscribe_or_create(topic, msg_sender, response_sender, heartbeat_enabled, operation_name);
688                            }
689                            Notification::Subscribe {
690                                topic,
691                                response_sender,
692                            } => {
693                                pubsub.subscribe(topic, response_sender);
694                            }
695                            Notification::SubscribeIfExist {
696                                topic,
697                                response_sender,
698                            } => {
699                                if pubsub.is_used(&topic) {
700                                    pubsub.subscribe(topic, response_sender);
701                                } else {
702                                    pubsub.force_delete(topic);
703                                    let _ = response_sender.send(None);
704                                }
705                            }
706                            Notification::InvalidIds {
707                                topics,
708                                response_sender,
709                            } => {
710                                let invalid_topics = pubsub.invalid_topics(topics);
711                                let _ = response_sender.send(invalid_topics);
712                            }
713                            Notification::UpdateHeartbeat {
714                                mut new_ttl
715                            } => {
716                                // We accept to miss max 3 heartbeats before cutting the connection
717                                new_ttl = new_ttl.map(|ttl| ttl * 3);
718                                if ttl != new_ttl {
719                                    ttl = new_ttl;
720                                    pubsub.ttl = new_ttl;
721                                    match new_ttl {
722                                        Some(new_ttl) => {
723                                            ttl_fut = Box::new(IntervalStream::new(tokio::time::interval(new_ttl)));
724                                        },
725                                        None => {
726                                            ttl_fut = Box::new(tokio_stream::pending());
727                                        }
728                                    }
729                                }
730
731                            }
732                            Notification::Exist {
733                                topic,
734                                response_sender,
735                            } => {
736                                let exist = pubsub.exist(&topic);
737                                let _ = response_sender.send(exist);
738                                if exist {
739                                    pubsub.touch(&topic);
740                                }
741                            }
742                            #[cfg(test)]
743                            Notification::TryDelete { topic } => pubsub.try_delete(topic),
744                            #[cfg(test)]
745                            Notification::Broadcast { data } => {
746                                pubsub.broadcast(data).await;
747                            }
748                            #[cfg(test)]
749                            Notification::Debug { response_sender } => {
750                                let _ = response_sender.send(pubsub.subscriptions.len());
751                            }
752                        }
753                    },
754                    None => break,
755                }
756            }
757        }
758    }
759}
760
761#[derive(Debug)]
762struct Subscription<V> {
763    msg_sender: broadcast::Sender<Option<V>>,
764    closing_signal: broadcast::Sender<()>,
765    heartbeat_enabled: bool,
766    updated_at: Instant,
767    _metric_guard: UpDownCounterGuard<i64>,
768}
769
770impl<V> Subscription<V> {
771    fn new(
772        msg_sender: broadcast::Sender<Option<V>>,
773        closing_signal: broadcast::Sender<()>,
774        heartbeat_enabled: bool,
775        operation_name: Option<String>,
776    ) -> Self {
777        let metric_guard = i64_up_down_counter!(
778            "apollo.router.opened.subscriptions",
779            "Number of opened subscriptions",
780            1,
781            graphql.operation.name = operation_name.unwrap_or_default()
782        );
783
784        Self {
785            msg_sender,
786            closing_signal,
787            heartbeat_enabled,
788            updated_at: Instant::now(),
789            _metric_guard: metric_guard,
790        }
791    }
792    // Update the updated_at value
793    fn touch(&mut self) {
794        self.updated_at = Instant::now();
795    }
796
797    fn closing_signal(&self) -> broadcast::Receiver<()> {
798        self.closing_signal.subscribe()
799    }
800}
801
802struct PubSub<K, V>
803where
804    K: Hash + Eq,
805{
806    subscriptions: HashMap<K, Subscription<V>>,
807    ttl: Option<Duration>,
808}
809
810impl<K, V> Default for PubSub<K, V>
811where
812    K: Hash + Eq,
813{
814    fn default() -> Self {
815        Self {
816            // subscribers: HashMap::new(),
817            subscriptions: HashMap::new(),
818            ttl: None,
819        }
820    }
821}
822
823impl<K, V> PubSub<K, V>
824where
825    K: Hash + Eq + Clone,
826    V: Clone + 'static,
827{
828    fn new(ttl: Option<Duration>) -> Self {
829        Self {
830            subscriptions: HashMap::new(),
831            ttl,
832        }
833    }
834
835    fn create_topic(
836        &mut self,
837        topic: K,
838        sender: broadcast::Sender<Option<V>>,
839        heartbeat_enabled: bool,
840        operation_name: Option<String>,
841    ) -> broadcast::Receiver<()> {
842        let (closing_signal_tx, closing_signal_rx) = broadcast::channel(1);
843        self.subscriptions.insert(
844            topic,
845            Subscription::new(
846                sender,
847                closing_signal_tx,
848                heartbeat_enabled,
849                operation_name.clone(),
850            ),
851        );
852
853        closing_signal_rx
854    }
855
856    fn subscribe(&mut self, topic: K, sender: ResponseSender<V>) {
857        match self.subscriptions.get_mut(&topic) {
858            Some(subscription) => {
859                let _ = sender.send(Some((
860                    subscription.msg_sender.clone(),
861                    subscription.msg_sender.subscribe(),
862                )));
863            }
864            None => {
865                let _ = sender.send(None);
866            }
867        }
868    }
869
870    fn subscribe_or_create(
871        &mut self,
872        topic: K,
873        msg_sender: broadcast::Sender<Option<V>>,
874        sender: ResponseSenderWithCreated<V>,
875        heartbeat_enabled: bool,
876        operation_name: Option<String>,
877    ) {
878        match self.subscriptions.get(&topic) {
879            Some(subscription) => {
880                let _ = sender.send(CreatedTopicPayload {
881                    msg_sender: subscription.msg_sender.clone(),
882                    msg_receiver: subscription.msg_sender.subscribe(),
883                    closing_signal: subscription.closing_signal(),
884                    created: false,
885                });
886            }
887            None => {
888                let closing_signal =
889                    self.create_topic(topic, msg_sender.clone(), heartbeat_enabled, operation_name);
890
891                let _ = sender.send(CreatedTopicPayload {
892                    msg_sender: msg_sender.clone(),
893                    msg_receiver: msg_sender.subscribe(),
894                    closing_signal,
895                    created: true,
896                });
897            }
898        }
899    }
900
901    fn unsubscribe(&mut self, topic: K) {
902        let mut topic_to_delete = false;
903        match self.subscriptions.get(&topic) {
904            Some(subscription) => {
905                topic_to_delete = subscription.msg_sender.receiver_count() == 0;
906            }
907            None => tracing::trace!("Cannot find the subscription to unsubscribe"),
908        }
909        if topic_to_delete {
910            tracing::trace!("deleting subscription from unsubscribe");
911            self.force_delete(topic);
912        };
913    }
914
915    /// Check if the topic is used by anyone else than the current handle
916    fn is_used(&self, topic: &K) -> bool {
917        self.subscriptions
918            .get(topic)
919            .map(|s| s.msg_sender.receiver_count() > 0)
920            .unwrap_or_default()
921    }
922
923    /// Update the heartbeat
924    fn touch(&mut self, topic: &K) {
925        if let Some(sub) = self.subscriptions.get_mut(topic) {
926            sub.touch();
927        }
928    }
929
930    /// Check if the topic exists
931    fn exist(&self, topic: &K) -> bool {
932        self.subscriptions.contains_key(topic)
933    }
934
935    /// Given a list of topics, returns the list of valid and invalid topics
936    /// Heartbeat the given valid topics
937    fn invalid_topics(&mut self, topics: Vec<K>) -> (Vec<K>, Vec<K>) {
938        topics.into_iter().fold(
939            (Vec::new(), Vec::new()),
940            |(mut valid_ids, mut invalid_ids), e| {
941                match self.subscriptions.get_mut(&e) {
942                    Some(sub) => {
943                        sub.touch();
944                        valid_ids.push(e);
945                    }
946                    None => {
947                        invalid_ids.push(e);
948                    }
949                }
950
951                (valid_ids, invalid_ids)
952            },
953        )
954    }
955
956    /// clean all topics which didn't heartbeat
957    async fn kill_dead_topics(&mut self, heartbeat_error_message: Option<V>) {
958        if let Some(ttl) = self.ttl {
959            let drained = self.subscriptions.drain();
960            let (remaining_subs, closed_subs) = drained.into_iter().fold(
961                (HashMap::new(), HashMap::new()),
962                |(mut acc, mut acc_error), (topic, sub)| {
963                    if (!sub.heartbeat_enabled || sub.updated_at.elapsed() <= ttl)
964                        && sub.msg_sender.receiver_count() > 0
965                    {
966                        acc.insert(topic, sub);
967                    } else {
968                        acc_error.insert(topic, sub);
969                    }
970
971                    (acc, acc_error)
972                },
973            );
974            self.subscriptions = remaining_subs;
975
976            // Send error message to all killed connections
977            for (_, subscription) in closed_subs {
978                tracing::trace!("deleting subscription from kill_dead_topics");
979                self._force_delete(subscription, heartbeat_error_message.as_ref());
980            }
981        }
982    }
983
/// Test-only: delete `topic` unless more than one receiver is still attached.
///
/// An unknown topic falls through to `force_delete`, which does nothing for
/// a topic that is not registered.
#[cfg(test)]
fn try_delete(&mut self, topic: K) {
    let still_shared = self
        .subscriptions
        .get(&topic)
        .is_some_and(|sub| sub.msg_sender.receiver_count() > 1);

    if !still_shared {
        self.force_delete(topic);
    }
}
994
995    fn force_delete(&mut self, topic: K) {
996        tracing::trace!("deleting subscription from force_delete");
997        let sub = self.subscriptions.remove(&topic);
998        if let Some(sub) = sub {
999            self._force_delete(sub, None);
1000        }
1001    }
1002
1003    fn _force_delete(&mut self, sub: Subscription<V>, error_message: Option<&V>) {
1004        tracing::trace!("deleting subscription from _force_delete");
1005        if let Some(error_message) = error_message {
1006            let _ = sub.msg_sender.send(error_message.clone().into());
1007        }
1008        let _ = sub.msg_sender.send(None);
1009        let _ = sub.closing_signal.send(());
1010    }
1011
/// Test-only: send a clone of `value` to every registered topic.
///
/// Topics whose send failed, as well as topics whose sender has no receiver
/// left, are removed from the map afterwards. Always returns `Some(())`.
#[cfg(test)]
async fn broadcast(&mut self, value: V) -> Option<()>
where
    V: Clone,
{
    // Deliver to each topic, remembering the ones that could not be reached.
    let failed_topics: Vec<K> = self
        .subscriptions
        .iter()
        .filter_map(|(topic, sub)| {
            let delivery = sub.msg_sender.send(Some(value.clone()));
            delivery.is_err().then(|| topic.clone())
        })
        .collect();

    // clean closed sender
    self.subscriptions.retain(|topic, sub| {
        sub.msg_sender.receiver_count() > 0 && !failed_topics.contains(topic)
    });

    Some(())
}
1035}
1036
1037pub(crate) struct RouterBroadcasts {
1038    configuration: (
1039        broadcast::Sender<Weak<Configuration>>,
1040        broadcast::Receiver<Weak<Configuration>>,
1041    ),
1042    schema: (
1043        broadcast::Sender<Arc<Schema>>,
1044        broadcast::Receiver<Arc<Schema>>,
1045    ),
1046}
1047
1048impl RouterBroadcasts {
1049    pub(crate) fn new() -> Self {
1050        Self {
1051            // Set to 2 to avoid potential deadlock when triggering a config/schema change mutiple times in a row
1052            configuration: broadcast::channel(2),
1053            schema: broadcast::channel(2),
1054        }
1055    }
1056
1057    pub(crate) fn subscribe_configuration(
1058        &self,
1059    ) -> impl Stream<Item = Weak<Configuration>> + use<> {
1060        BroadcastStream::new(self.configuration.0.subscribe())
1061            .filter_map(|cfg| futures::future::ready(cfg.ok()))
1062    }
1063
1064    pub(crate) fn subscribe_schema(&self) -> impl Stream<Item = Arc<Schema>> + use<> {
1065        BroadcastStream::new(self.schema.0.subscribe())
1066            .filter_map(|schema| futures::future::ready(schema.ok()))
1067    }
1068}
1069
1070#[cfg(test)]
1071mod tests {
1072
1073    use futures::FutureExt;
1074    use tokio_stream::StreamExt;
1075    use uuid::Uuid;
1076
1077    use super::*;
1078    use crate::metrics::FutureMetricsExt;
1079
1080    #[tokio::test]
1081    async fn subscribe() {
1082        let mut notify = Notify::builder().build();
1083        let topic_1 = Uuid::new_v4();
1084        let topic_2 = Uuid::new_v4();
1085
1086        let (handle1, created, mut subscription_closing_signal_1) = notify
1087            .create_or_subscribe(topic_1, false, None)
1088            .await
1089            .unwrap();
1090        assert!(created);
1091        let (_handle2, created, mut subscription_closing_signal_2) = notify
1092            .create_or_subscribe(topic_2, false, None)
1093            .await
1094            .unwrap();
1095        assert!(created);
1096
1097        let handle_1_bis = notify.subscribe(topic_1).await.unwrap();
1098        let handle_1_other = notify.subscribe(topic_1).await.unwrap();
1099        let mut cloned_notify = notify.clone();
1100
1101        let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
1102        handle
1103            .send_sync(serde_json_bytes::json!({"test": "ok"}))
1104            .unwrap();
1105        drop(handle);
1106        drop(handle1);
1107        let mut handle_1_bis = handle_1_bis.into_stream();
1108        let new_msg = handle_1_bis.next().await.unwrap();
1109        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
1110        let mut handle_1_other = handle_1_other.into_stream();
1111        let new_msg = handle_1_other.next().await.unwrap();
1112        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
1113
1114        assert!(notify.exist(topic_1).await.unwrap());
1115        assert!(notify.exist(topic_2).await.unwrap());
1116
1117        drop(_handle2);
1118        drop(handle_1_bis);
1119        drop(handle_1_other);
1120
1121        let subscriptions_nb = notify.debug().await.unwrap();
1122        assert_eq!(subscriptions_nb, 0);
1123
1124        subscription_closing_signal_1.try_recv().unwrap();
1125        subscription_closing_signal_2.try_recv().unwrap();
1126    }
1127
1128    #[tokio::test]
1129    async fn it_subscribe_and_delete() {
1130        let mut notify = Notify::builder().build();
1131        let topic_1 = Uuid::new_v4();
1132        let topic_2 = Uuid::new_v4();
1133
1134        let (handle1, created, mut subscription_closing_signal_1) = notify
1135            .create_or_subscribe(topic_1, true, None)
1136            .await
1137            .unwrap();
1138        assert!(created);
1139        let (_handle2, created, mut subscription_closing_signal_2) = notify
1140            .create_or_subscribe(topic_2, true, None)
1141            .await
1142            .unwrap();
1143        assert!(created);
1144
1145        let mut _handle_1_bis = notify.subscribe(topic_1).await.unwrap();
1146        let mut _handle_1_other = notify.subscribe(topic_1).await.unwrap();
1147        let mut cloned_notify = notify.clone();
1148        let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
1149        handle
1150            .send_sync(serde_json_bytes::json!({"test": "ok"}))
1151            .unwrap();
1152        drop(handle);
1153        assert!(notify.exist(topic_1).await.unwrap());
1154        drop(_handle_1_bis);
1155        drop(_handle_1_other);
1156
1157        notify.try_delete(topic_1).unwrap();
1158
1159        let subscriptions_nb = notify.debug().await.unwrap();
1160        assert_eq!(subscriptions_nb, 1);
1161
1162        assert!(!notify.exist(topic_1).await.unwrap());
1163
1164        notify.force_delete(topic_1).await.unwrap();
1165
1166        let mut handle1 = handle1.into_stream();
1167        let new_msg = handle1.next().await.unwrap();
1168        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
1169        assert!(handle1.next().await.is_none());
1170        assert!(notify.exist(topic_2).await.unwrap());
1171        notify.try_delete(topic_2).unwrap();
1172
1173        let subscriptions_nb = notify.debug().await.unwrap();
1174        assert_eq!(subscriptions_nb, 0);
1175        drop(handle1);
1176        subscription_closing_signal_1.try_recv().unwrap();
1177        subscription_closing_signal_2.try_recv().unwrap();
1178    }
1179
1180    #[tokio::test]
1181    async fn it_subscribe_and_delete_metrics() {
1182        async {
1183            let mut notify = Notify::builder().build();
1184            let topic_1 = Uuid::new_v4();
1185            let topic_2 = Uuid::new_v4();
1186
1187            let (handle1, created, mut subscription_closing_signal_1) = notify
1188                .create_or_subscribe(topic_1, true, Some("TestSubscription".to_string()))
1189                .await
1190                .unwrap();
1191            assert!(created);
1192            let (_handle2, created, mut subscription_closing_signal_2) = notify
1193                .create_or_subscribe(topic_2, true, Some("TestSubscriptionBis".to_string()))
1194                .await
1195                .unwrap();
1196            assert!(created);
1197            assert_up_down_counter!(
1198                "apollo.router.opened.subscriptions",
1199                1i64,
1200                "graphql.operation.name" = "TestSubscription"
1201            );
1202            assert_up_down_counter!(
1203                "apollo.router.opened.subscriptions",
1204                1i64,
1205                "graphql.operation.name" = "TestSubscriptionBis"
1206            );
1207
1208            let mut _handle_1_bis = notify.subscribe(topic_1).await.unwrap();
1209            let mut _handle_1_other = notify.subscribe(topic_1).await.unwrap();
1210            let mut cloned_notify = notify.clone();
1211            let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
1212            handle
1213                .send_sync(serde_json_bytes::json!({"test": "ok"}))
1214                .unwrap();
1215            drop(handle);
1216            assert!(notify.exist(topic_1).await.unwrap());
1217            drop(_handle_1_bis);
1218            drop(_handle_1_other);
1219
1220            notify.try_delete(topic_1).unwrap();
1221            assert_up_down_counter!(
1222                "apollo.router.opened.subscriptions",
1223                1i64,
1224                "graphql.operation.name" = "TestSubscription"
1225            );
1226            assert_up_down_counter!(
1227                "apollo.router.opened.subscriptions",
1228                1i64,
1229                "graphql.operation.name" = "TestSubscriptionBis"
1230            );
1231
1232            let subscriptions_nb = notify.debug().await.unwrap();
1233            assert_eq!(subscriptions_nb, 1);
1234
1235            assert!(!notify.exist(topic_1).await.unwrap());
1236
1237            notify.force_delete(topic_1).await.unwrap();
1238            assert_up_down_counter!(
1239                "apollo.router.opened.subscriptions",
1240                0i64,
1241                "graphql.operation.name" = "TestSubscription"
1242            );
1243            assert_up_down_counter!(
1244                "apollo.router.opened.subscriptions",
1245                1i64,
1246                "graphql.operation.name" = "TestSubscriptionBis"
1247            );
1248
1249            let mut handle1 = handle1.into_stream();
1250            let new_msg = handle1.next().await.unwrap();
1251            assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
1252            assert!(handle1.next().await.is_none());
1253            assert!(notify.exist(topic_2).await.unwrap());
1254            notify.try_delete(topic_2).unwrap();
1255
1256            let subscriptions_nb = notify.debug().await.unwrap();
1257            assert_eq!(subscriptions_nb, 0);
1258            assert_up_down_counter!(
1259                "apollo.router.opened.subscriptions",
1260                0i64,
1261                "graphql.operation.name" = "TestSubscription"
1262            );
1263            assert_up_down_counter!(
1264                "apollo.router.opened.subscriptions",
1265                0i64,
1266                "graphql.operation.name" = "TestSubscriptionBis"
1267            );
1268            subscription_closing_signal_1.try_recv().unwrap();
1269            subscription_closing_signal_2.try_recv().unwrap();
1270        }
1271        .with_metrics()
1272        .await;
1273    }
1274
1275    #[tokio::test]
1276    async fn it_test_ttl() {
1277        let mut notify = Notify::builder()
1278            .ttl(Duration::from_millis(300))
1279            .heartbeat_error_message(serde_json_bytes::json!({"error": "connection_closed"}))
1280            .build();
1281        let topic_1 = Uuid::new_v4();
1282        let topic_2 = Uuid::new_v4();
1283
1284        let (handle1, created, mut subscription_closing_signal_1) = notify
1285            .create_or_subscribe(topic_1, true, None)
1286            .await
1287            .unwrap();
1288        assert!(created);
1289        let (_handle2, created, mut subscription_closing_signal_2) = notify
1290            .create_or_subscribe(topic_2, true, None)
1291            .await
1292            .unwrap();
1293        assert!(created);
1294
1295        let handle_1_bis = notify.subscribe(topic_1).await.unwrap();
1296        let handle_1_other = notify.subscribe(topic_1).await.unwrap();
1297        let mut cloned_notify = notify.clone();
1298        tokio::spawn(async move {
1299            let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
1300            handle
1301                .send_sync(serde_json_bytes::json!({"test": "ok"}))
1302                .unwrap();
1303        });
1304        drop(handle1);
1305
1306        let mut handle_1_bis = handle_1_bis.into_stream();
1307        let new_msg = handle_1_bis.next().await.unwrap();
1308        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
1309        let mut handle_1_other = handle_1_other.into_stream();
1310        let new_msg = handle_1_other.next().await.unwrap();
1311        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
1312
1313        notify
1314            .set_ttl(Duration::from_millis(70).into())
1315            .await
1316            .unwrap();
1317
1318        tokio::time::sleep(Duration::from_millis(150)).await;
1319        let mut cloned_notify = notify.clone();
1320        tokio::spawn(async move {
1321            let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
1322            handle
1323                .send_sync(serde_json_bytes::json!({"test": "ok"}))
1324                .unwrap();
1325        });
1326        let new_msg = handle_1_bis.next().await.unwrap();
1327        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
1328        tokio::time::sleep(Duration::from_millis(150)).await;
1329
1330        let res = handle_1_bis.next().now_or_never().unwrap();
1331        assert_eq!(
1332            res,
1333            Some(serde_json_bytes::json!({"error": "connection_closed"}))
1334        );
1335
1336        assert!(handle_1_bis.next().await.is_none());
1337
1338        assert!(!notify.exist(topic_1).await.unwrap());
1339        assert!(!notify.exist(topic_2).await.unwrap());
1340        subscription_closing_signal_1.try_recv().unwrap();
1341        subscription_closing_signal_2.try_recv().unwrap();
1342
1343        let subscriptions_nb = notify.debug().await.unwrap();
1344        assert_eq!(subscriptions_nb, 0);
1345    }
1346}