Skip to main content

apollo_router/plugins/subscription/
notification.rs

1//! Internal pub/sub facility for subscription
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::hash::Hash;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::sync::Weak;
8use std::task::Context;
9use std::task::Poll;
10use std::time::Duration;
11use std::time::Instant;
12
13use futures::Sink;
14use futures::Stream;
15use futures::StreamExt;
16use pin_project_lite::pin_project;
17use thiserror::Error;
18use tokio::sync::broadcast;
19use tokio::sync::mpsc;
20use tokio::sync::mpsc::error::SendError;
21use tokio::sync::mpsc::error::TrySendError;
22use tokio::sync::oneshot;
23use tokio::sync::oneshot::error::RecvError;
24use tokio_stream::wrappers::BroadcastStream;
25use tokio_stream::wrappers::IntervalStream;
26use tokio_stream::wrappers::ReceiverStream;
27use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
28
29use crate::Configuration;
30use crate::allocator::WithMemoryTracking;
31use crate::graphql;
32use crate::metrics::FutureMetricsExt;
33use crate::metrics::UpDownCounterGuard;
34use crate::spec::Schema;
35
/// Capacity of the mpsc channel that carries `Notification` commands to the pub/sub task.
const NOTIFY_CHANNEL_SIZE: usize = 1024;
/// Capacity of a topic's broadcast channel when `Notify::queue_size` is `None`.
const DEFAULT_MSG_CHANNEL_SIZE: usize = 128;
38
/// Errors returned by the `Notify` API and by the handles it creates.
#[derive(Error, Debug)]
pub(crate) enum NotifyError<K, V> {
    /// A oneshot reply channel was closed before a reply was received.
    #[error("cannot receive data from pubsub")]
    RecvError(#[from] RecvError),
    /// Sending a value of type `V` over an mpsc channel failed.
    #[error("cannot send data to pubsub")]
    SendError(#[from] SendError<V>),
    /// Sending a `Notification` command over the mpsc channel failed.
    #[error("cannot send data to pubsub")]
    NotificationSendError(#[from] SendError<Notification<K, V>>),
    /// A non-blocking `try_send` of a `Notification` command failed.
    #[error("cannot send data to pubsub")]
    NotificationTrySendError(#[from] TrySendError<Notification<K, V>>),
    /// Sending a message on a topic's broadcast channel failed.
    #[error("cannot send data to response stream")]
    BroadcastSendError(#[from] broadcast::error::SendError<V>),
    /// The requested topic is not registered.
    #[error("this topic doesn't exist")]
    UnknownTopic,
}
54
/// Oneshot reply channel for `Subscribe` / `SubscribeIfExist`.
///
/// The reply is `Some((sender, receiver))` for the topic's broadcast channel,
/// or `None` when no subscription could be provided.
type ResponseSender<V> =
    oneshot::Sender<Option<(broadcast::Sender<Option<V>>, broadcast::Receiver<Option<V>>)>>;
57
/// Reply to a `Notification::CreateOrSubscribe` command.
pub(crate) struct CreatedTopicPayload<V> {
    // Sender side of the topic's broadcast channel
    msg_sender: broadcast::Sender<Option<V>>,
    // Fresh receiver subscribed to the topic's broadcast channel
    msg_receiver: broadcast::Receiver<Option<V>>,
    // Receiver of the topic's closing signal
    closing_signal: broadcast::Receiver<()>,
    // `true` when the command created the topic, `false` when the topic already existed
    created: bool,
}
64
/// Oneshot reply channel for `Notification::CreateOrSubscribe`.
type ResponseSenderWithCreated<V> = oneshot::Sender<CreatedTopicPayload<V>>;
66
/// Commands sent over the mpsc channel to the pub/sub task.
pub(crate) enum Notification<K, V> {
    /// Subscribe to `topic`, creating it first when it is not registered yet.
    CreateOrSubscribe {
        topic: K,
        // Sender connected to the original source stream
        msg_sender: broadcast::Sender<Option<V>>,
        // To know if it has been created or re-used
        response_sender: ResponseSenderWithCreated<V>,
        // Stored on the subscription when the topic is created
        heartbeat_enabled: bool,
        // Useful for the metric we create
        operation_name: Option<String>,
    },
    /// Subscribe to an existing `topic`; the reply is `None` when it is not registered.
    Subscribe {
        topic: K,
        // Oneshot channel to fetch the receiver
        response_sender: ResponseSender<V>,
    },
    /// Subscribe to `topic` only if it currently has at least one receiver;
    /// otherwise the task force-deletes the topic and replies `None`.
    SubscribeIfExist {
        topic: K,
        // Oneshot channel to fetch the receiver
        response_sender: ResponseSender<V>,
    },
    /// Sent when a handle guard is dropped; the topic is deleted once its
    /// broadcast channel has no receiver left.
    Unsubscribe {
        topic: K,
    },
    /// Delete `topic` even if several subscribers are still listening.
    ForceDelete {
        topic: K,
    },
    /// Ask whether `topic` is registered; an existing topic also gets its
    /// heartbeat timestamp refreshed.
    Exist {
        topic: K,
        response_sender: oneshot::Sender<bool>,
    },
    /// Partition `topics` into two lists; registered topics get their heartbeat
    /// timestamp refreshed.
    InvalidIds {
        topics: Vec<K>,
        // NOTE(review): the tuple appears to be (valid, invalid) — confirm against `PubSub::invalid_topics`
        response_sender: oneshot::Sender<(Vec<K>, Vec<K>)>,
    },
    /// Change the heartbeat TTL; the task multiplies the value by 3 before applying it.
    UpdateHeartbeat {
        new_ttl: Option<Duration>,
    },
    /// Test only: forwarded to `PubSub::try_delete`.
    #[cfg(test)]
    TryDelete {
        topic: K,
    },
    /// Test only: forwarded to `PubSub::broadcast`.
    #[cfg(test)]
    Broadcast {
        data: V,
    },
    /// Test only: introspection of the task state.
    #[cfg(test)]
    Debug {
        // Replies with the number of registered subscriptions (topics)
        response_sender: oneshot::Sender<usize>,
    },
}
119
120impl<K, V> Debug for Notification<K, V> {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        match self {
123            Self::CreateOrSubscribe { .. } => f.debug_struct("CreateOrSubscribe").finish(),
124            Self::Subscribe { .. } => f.debug_struct("Subscribe").finish(),
125            Self::SubscribeIfExist { .. } => f.debug_struct("SubscribeIfExist").finish(),
126            Self::Unsubscribe { .. } => f.debug_struct("Unsubscribe").finish(),
127            Self::ForceDelete { .. } => f.debug_struct("ForceDelete").finish(),
128            Self::Exist { .. } => f.debug_struct("Exist").finish(),
129            Self::InvalidIds { .. } => f.debug_struct("InvalidIds").finish(),
130            Self::UpdateHeartbeat { .. } => f.debug_struct("UpdateHeartbeat").finish(),
131            #[cfg(test)]
132            Self::TryDelete { .. } => f.debug_struct("TryDelete").finish(),
133            #[cfg(test)]
134            Self::Broadcast { .. } => f.debug_struct("Broadcast").finish(),
135            #[cfg(test)]
136            Self::Debug { .. } => f.debug_struct("Debug").finish(),
137        }
138    }
139}
140
/// In memory pub/sub implementation
///
/// Cloning a `Notify` yields another front-end to the same command channel.
#[derive(Clone)]
pub struct Notify<K, V> {
    /// Command channel to the pub/sub task
    sender: mpsc::Sender<Notification<K, V>>,
    /// Size (number of events) of the channel to receive message
    pub(crate) queue_size: Option<usize>,
    /// Shared channels used to broadcast configuration and schema updates
    router_broadcasts: Arc<RouterBroadcasts>,
}
149
150#[buildstructor::buildstructor]
151impl<K, V> Notify<K, V>
152where
153    K: Send + Hash + Eq + Clone + 'static,
154    V: Send + Sync + Clone + 'static,
155{
156    #[builder]
157    pub(crate) fn new(
158        ttl: Option<Duration>,
159        heartbeat_error_message: Option<V>,
160        queue_size: Option<usize>,
161    ) -> Notify<K, V> {
162        let (sender, receiver) = mpsc::channel(NOTIFY_CHANNEL_SIZE);
163        let receiver_stream: ReceiverStream<Notification<K, V>> = ReceiverStream::new(receiver);
164        tokio::task::spawn(
165            task(receiver_stream, ttl, heartbeat_error_message)
166                .with_current_meter_provider()
167                .with_memory_tracking("subscription.task"),
168        );
169        Notify {
170            sender,
171            queue_size,
172            router_broadcasts: Arc::new(RouterBroadcasts::new()),
173        }
174    }
175
176    #[doc(hidden)]
177    /// NOOP notifier for tests
178    pub fn for_tests() -> Self {
179        let (sender, _receiver) = mpsc::channel(NOTIFY_CHANNEL_SIZE);
180        Notify {
181            sender,
182            queue_size: None,
183            router_broadcasts: Arc::new(RouterBroadcasts::new()),
184        }
185    }
186}
187
188impl<K, V> Notify<K, V> {
189    /// Broadcast a new configuration
190    pub(crate) fn broadcast_configuration(&self, configuration: Weak<Configuration>) {
191        self.router_broadcasts.configuration.0.send(configuration).expect("cannot send the configuration update to the static channel. Should not happen because the receiver will always live in this struct; qed");
192    }
193    /// Receive the new configuration everytime we have a new router configuration
194    pub(crate) fn subscribe_configuration(
195        &self,
196    ) -> impl Stream<Item = Weak<Configuration>> + use<K, V> {
197        self.router_broadcasts.subscribe_configuration()
198    }
199    /// Receive the new schema everytime we have a new schema
200    pub(crate) fn broadcast_schema(&self, schema: Arc<Schema>) {
201        self.router_broadcasts.schema.0.send(schema).expect("cannot send the schema update to the static channel. Should not happen because the receiver will always live in this struct; qed");
202    }
203    /// Receive the new schema everytime we have a new schema
204    pub(crate) fn subscribe_schema(&self) -> impl Stream<Item = Arc<Schema>> + use<K, V> {
205        self.router_broadcasts.subscribe_schema()
206    }
207}
208
209impl<K, V> Notify<K, V>
210where
211    K: Send + Hash + Eq + Clone + 'static,
212    V: Send + Clone + 'static,
213{
214    pub(crate) async fn set_ttl(&self, new_ttl: Option<Duration>) -> Result<(), NotifyError<K, V>> {
215        self.sender
216            .send(Notification::UpdateHeartbeat { new_ttl })
217            .await?;
218
219        Ok(())
220    }
221
222    /// Creates or subscribes to a topic, returning a handle and subscription state.
223    ///
224    /// The `Ok()` branch of the `Result` is a tuple, where:
225    ///     - .0: a `Handle` on the subscription event listener,
226    ///     - .1: a boolean, where
227    ///              - `true`: call to this fn `created` this subscription, and
228    ///              - `false`: call to this fn was for a deduplicated subscription
229    ///                         i.e. subscription already exists,
230    ///     - .2: a closing signal in a form of `broadcast::Receiver` that gets
231    ///            triggered once the subscription is closed.
232    ///
233    /// # Closing Signal Usage
234    ///
235    /// The closing signal's usage depends on how subscriptions are managed:
236    ///
237    /// ## Callback Mode (HTTP-based subscriptions)
238    /// - The closing signal is typically **unused** as there are no long-running
239    ///   forwarding tasks to clean up
240    /// - Subscriptions are managed via HTTP callbacks to a public URL
241    /// - Subscription lifecycle is controlled through HTTP responses (404 closes the subscription)
242    /// - Always called with `heartbeat_enabled = true` to enable TTL-based timeout checking
243    ///
244    /// ## Passthrough Mode (WebSocket-based subscriptions)  
245    /// - The closing signal **must be monitored** by the forwarding task using `tokio::select!`
246    /// - Maintains persistent WebSocket connections to subgraphs
247    /// - Needed for proper cleanup when subscriptions are terminated, especially important
248    ///   for deduplication (multiple clients may share one subgraph connection)
249    /// - Always called with `heartbeat_enabled = false` as WebSockets have their own
250    ///   connection management
251    ///
252    /// # Parameters
253    /// - `topic`: The subscription topic identifier
254    /// - `heartbeat_enabled`: Controls TTL-based timeout checking at the notification layer:
255    ///   - `true`: Enables TTL checking. For callback mode, subscriptions will timeout if
256    ///     no heartbeat is received within the TTL period. The actual heartbeat interval
257    ///     is configured separately and sent to subgraphs in the subscription extension.
258    ///     When subgraphs send heartbeat messages, they're processed via `invalid_ids()`
259    ///     which calls `touch()` to update the subscription's `updated_at` timestamp.
260    ///   - `false`: Disables TTL checking (used by passthrough/WebSocket mode)
261    /// - `operation_name`: Optional GraphQL operation name for metrics
262    ///
263    /// # Heartbeat Processing for Callback Mode
264    ///
265    /// When callback mode is configured with a heartbeat interval:
266    /// 1. The interval is converted to milliseconds and sent to the subgraph as
267    ///    `heartbeat_interval_ms` in the subscription extension
268    /// 2. Subgraphs send periodic heartbeat callbacks with subscription IDs
269    /// 3. The heartbeat handler validates IDs and calls `notify.invalid_ids()`
270    /// 4. This updates each valid subscription's timestamp via `touch()`
271    /// 5. The TTL checker uses these timestamps to determine if subscriptions are alive
272    ///    and closes those that haven't been touched within the TTL period
273    pub(crate) async fn create_or_subscribe(
274        &mut self,
275        topic: K,
276        heartbeat_enabled: bool,
277        operation_name: Option<String>,
278    ) -> Result<(Handle<K, V>, bool, broadcast::Receiver<()>), NotifyError<K, V>> {
279        let (sender, _receiver) =
280            broadcast::channel(self.queue_size.unwrap_or(DEFAULT_MSG_CHANNEL_SIZE));
281
282        let (tx, rx) = oneshot::channel();
283        self.sender
284            .send(Notification::CreateOrSubscribe {
285                topic: topic.clone(),
286                msg_sender: sender,
287                response_sender: tx,
288                heartbeat_enabled,
289                operation_name,
290            })
291            .await?;
292
293        let CreatedTopicPayload {
294            msg_sender,
295            msg_receiver,
296            closing_signal,
297            created,
298        } = rx.await?;
299        let handle = Handle::new(
300            topic,
301            self.sender.clone(),
302            msg_sender,
303            BroadcastStream::from(msg_receiver),
304        );
305
306        Ok((handle, created, closing_signal))
307    }
308
309    pub(crate) async fn subscribe(&mut self, topic: K) -> Result<Handle<K, V>, NotifyError<K, V>> {
310        let (sender, receiver) = oneshot::channel();
311
312        self.sender
313            .send(Notification::Subscribe {
314                topic: topic.clone(),
315                response_sender: sender,
316            })
317            .await?;
318
319        let Some((msg_sender, msg_receiver)) = receiver.await? else {
320            return Err(NotifyError::UnknownTopic);
321        };
322        let handle = Handle::new(
323            topic,
324            self.sender.clone(),
325            msg_sender,
326            BroadcastStream::from(msg_receiver),
327        );
328
329        Ok(handle)
330    }
331
332    pub(crate) async fn subscribe_if_exist(
333        &mut self,
334        topic: K,
335    ) -> Result<Option<Handle<K, V>>, NotifyError<K, V>> {
336        let (sender, receiver) = oneshot::channel();
337
338        self.sender
339            .send(Notification::SubscribeIfExist {
340                topic: topic.clone(),
341                response_sender: sender,
342            })
343            .await?;
344
345        let Some((msg_sender, msg_receiver)) = receiver.await? else {
346            return Ok(None);
347        };
348        let handle = Handle::new(
349            topic,
350            self.sender.clone(),
351            msg_sender,
352            BroadcastStream::from(msg_receiver),
353        );
354
355        Ok(handle.into())
356    }
357
358    pub(crate) async fn exist(&mut self, topic: K) -> Result<bool, NotifyError<K, V>> {
359        // Channel to check if the topic still exists or not
360        let (response_tx, response_rx) = oneshot::channel();
361
362        self.sender
363            .send(Notification::Exist {
364                topic,
365                response_sender: response_tx,
366            })
367            .await?;
368
369        let resp = response_rx.await?;
370
371        Ok(resp)
372    }
373
374    pub(crate) async fn invalid_ids(
375        &mut self,
376        topics: Vec<K>,
377    ) -> Result<(Vec<K>, Vec<K>), NotifyError<K, V>> {
378        // Channel to check if the topic still exists or not
379        let (response_tx, response_rx) = oneshot::channel();
380
381        self.sender
382            .send(Notification::InvalidIds {
383                topics,
384                response_sender: response_tx,
385            })
386            .await?;
387
388        let resp = response_rx.await?;
389
390        Ok(resp)
391    }
392
393    /// Delete the topic even if several subscribers are still listening
394    pub(crate) async fn force_delete(&mut self, topic: K) -> Result<(), NotifyError<K, V>> {
395        // if disconnected, we don't care (the task was stopped)
396        self.sender
397            .send(Notification::ForceDelete { topic })
398            .await
399            .map_err(std::convert::Into::into)
400    }
401
402    /// Delete the topic if and only if one or zero subscriber is still listening
403    /// This function is not async to allow it to be used in a Drop impl
404    #[cfg(test)]
405    pub(crate) fn try_delete(&mut self, topic: K) -> Result<(), NotifyError<K, V>> {
406        // if disconnected, we don't care (the task was stopped)
407        self.sender
408            .try_send(Notification::TryDelete { topic })
409            .map_err(|try_send_error| try_send_error.into())
410    }
411
412    #[cfg(test)]
413    pub(crate) async fn broadcast(&mut self, data: V) -> Result<(), NotifyError<K, V>> {
414        self.sender
415            .send(Notification::Broadcast { data })
416            .await
417            .map_err(std::convert::Into::into)
418    }
419
420    #[cfg(test)]
421    pub(crate) async fn debug(&mut self) -> Result<usize, NotifyError<K, V>> {
422        let (response_tx, response_rx) = oneshot::channel();
423        self.sender
424            .send(Notification::Debug {
425                response_sender: response_tx,
426            })
427            .await?;
428
429        Ok(response_rx.await.unwrap())
430    }
431}
432
#[cfg(test)]
impl<K, V> Default for Notify<K, V>
where
    K: Send + Hash + Eq + Clone + 'static,
    V: Send + Sync + Clone + 'static,
{
    /// Test-only default: the same no-op notifier as [`Notify::for_tests`].
    fn default() -> Self {
        Notify::for_tests()
    }
}
444
445impl<K, V> Debug for Notify<K, V> {
446    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447        f.debug_struct("Notify").finish()
448    }
449}
450
/// Guard tied to a topic: when dropped it sends an `Unsubscribe` command for
/// that topic to the pub/sub task.
struct HandleGuard<K, V>
where
    K: Clone,
{
    // Topic this guard belongs to, shared between clones of the guard
    topic: Arc<K>,
    // Command channel to the pub/sub task
    pubsub_sender: mpsc::Sender<Notification<K, V>>,
}
458
459impl<K, V> Clone for HandleGuard<K, V>
460where
461    K: Clone,
462{
463    fn clone(&self) -> Self {
464        Self {
465            topic: self.topic.clone(),
466            pubsub_sender: self.pubsub_sender.clone(),
467        }
468    }
469}
470
471impl<K, V> Drop for HandleGuard<K, V>
472where
473    K: Clone,
474{
475    fn drop(&mut self) {
476        let err = self.pubsub_sender.try_send(Notification::Unsubscribe {
477            topic: self.topic.as_ref().clone(),
478        });
479        if let Err(err) = err {
480            tracing::trace!("cannot unsubscribe {err:?}");
481        }
482    }
483}
484
pin_project! {
/// Handle on a topic, holding both ends of the topic's broadcast channel.
///
/// It can be converted into a stream, a sink, or split into both.
pub struct Handle<K, V>
where
    K: Clone,
{
    // Sends `Unsubscribe` to the pub/sub task when dropped
    handle_guard: HandleGuard<K, V>,
    // Sender side of the topic's broadcast channel
    #[pin]
    msg_sender: broadcast::Sender<Option<V>>,
    // Receiver side of the topic's broadcast channel, wrapped as a stream
    #[pin]
    msg_receiver: BroadcastStream<Option<V>>,
}
}
497
498impl<K, V> Clone for Handle<K, V>
499where
500    K: Clone,
501    V: Clone + Send + 'static,
502{
503    fn clone(&self) -> Self {
504        Self {
505            handle_guard: self.handle_guard.clone(),
506            msg_receiver: BroadcastStream::new(self.msg_sender.subscribe()),
507            msg_sender: self.msg_sender.clone(),
508        }
509    }
510}
511
512impl<K, V> Handle<K, V>
513where
514    K: Clone,
515{
516    fn new(
517        topic: K,
518        pubsub_sender: mpsc::Sender<Notification<K, V>>,
519        msg_sender: broadcast::Sender<Option<V>>,
520        msg_receiver: BroadcastStream<Option<V>>,
521    ) -> Self {
522        Self {
523            handle_guard: HandleGuard {
524                topic: Arc::new(topic),
525                pubsub_sender,
526            },
527            msg_sender,
528            msg_receiver,
529        }
530    }
531
532    pub(crate) fn into_stream(self) -> HandleStream<K, V> {
533        HandleStream {
534            handle_guard: self.handle_guard,
535            msg_receiver: self.msg_receiver,
536        }
537    }
538
539    pub(crate) fn into_sink(self) -> HandleSink<K, V> {
540        HandleSink {
541            handle_guard: self.handle_guard,
542            msg_sender: self.msg_sender,
543        }
544    }
545
546    /// Return a sink and a stream
547    pub fn split(self) -> (HandleSink<K, V>, HandleStream<K, V>) {
548        (
549            HandleSink {
550                handle_guard: self.handle_guard.clone(),
551                msg_sender: self.msg_sender,
552            },
553            HandleStream {
554                handle_guard: self.handle_guard,
555                msg_receiver: self.msg_receiver,
556            },
557        )
558    }
559}
560
pin_project! {
/// Receiving half of a `Handle`: a stream of the messages published on the topic.
pub struct HandleStream<K, V>
where
    K: Clone,
{
    // Sends `Unsubscribe` to the pub/sub task when dropped
    handle_guard: HandleGuard<K, V>,
    // Receiver side of the topic's broadcast channel, wrapped as a stream
    #[pin]
    msg_receiver: BroadcastStream<Option<V>>,
}
}
571
572impl<K, V> Stream for HandleStream<K, V>
573where
574    K: Clone,
575    V: Clone + 'static + Send,
576{
577    type Item = V;
578
579    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
580        let mut this = self.as_mut().project();
581
582        match Pin::new(&mut this.msg_receiver).poll_next(cx) {
583            Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(_)))) => {
584                u64_counter!(
585                    "apollo.router.skipped.event.count",
586                    "Amount of events dropped from the internal message queue",
587                    1u64
588                );
589                self.poll_next(cx)
590            }
591            Poll::Ready(None) => Poll::Ready(None),
592            Poll::Ready(Some(Ok(Some(val)))) => Poll::Ready(Some(val)),
593            Poll::Ready(Some(Ok(None))) => Poll::Ready(None),
594            Poll::Pending => Poll::Pending,
595        }
596    }
597}
598
pin_project! {
/// Sending half of a `Handle`: publishes messages on the topic's broadcast channel.
pub struct HandleSink<K, V>
where
    K: Clone,
{
    // Sends `Unsubscribe` to the pub/sub task when dropped
    handle_guard: HandleGuard<K, V>,
    // Sender side of the topic's broadcast channel
    #[pin]
    msg_sender: broadcast::Sender<Option<V>>,
}
}
609
610impl<K, V> HandleSink<K, V>
611where
612    K: Clone,
613    V: Clone + 'static + Send,
614{
615    /// Send data to the subscribed topic
616    pub(crate) fn send_sync(&mut self, data: V) -> Result<(), NotifyError<K, V>> {
617        self.msg_sender.send(data.into()).map_err(|err| {
618            NotifyError::BroadcastSendError(broadcast::error::SendError(err.0.unwrap()))
619        })?;
620
621        Ok(())
622    }
623}
624
625impl<K, V> Sink<V> for HandleSink<K, V>
626where
627    K: Clone,
628    V: Clone + 'static + Send,
629{
630    type Error = graphql::Error;
631
632    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
633        Poll::Ready(Ok(()))
634    }
635
636    fn start_send(self: Pin<&mut Self>, item: V) -> Result<(), Self::Error> {
637        self.msg_sender.send(Some(item)).map_err(|_err| {
638            graphql::Error::builder()
639                .message("cannot send payload through pubsub")
640                .extension_code("NOTIFICATION_HANDLE_SEND_ERROR")
641                .build()
642        })?;
643        Ok(())
644    }
645
646    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
647        Poll::Ready(Ok(()))
648    }
649
650    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
651        let topic = self.handle_guard.topic.as_ref().clone();
652        let _ = self
653            .handle_guard
654            .pubsub_sender
655            .try_send(Notification::ForceDelete { topic });
656        Poll::Ready(Ok(()))
657    }
658}
659
// NOTE(review): this impl block is empty and adds no items; it looks like a
// leftover and can probably be removed — confirm before deleting.
impl<K, V> Handle<K, V> where K: Clone {}
661
662async fn task<K, V>(
663    mut receiver: ReceiverStream<Notification<K, V>>,
664    mut ttl: Option<Duration>,
665    heartbeat_error_message: Option<V>,
666) where
667    K: Send + Hash + Eq + Clone + 'static,
668    V: Send + Clone + 'static,
669{
670    let mut pubsub: PubSub<K, V> = PubSub::new(ttl);
671
672    let mut ttl_fut: Box<dyn Stream<Item = tokio::time::Instant> + Send + Unpin> = match ttl {
673        Some(ttl) => Box::new(IntervalStream::new(tokio::time::interval(ttl))),
674        None => Box::new(tokio_stream::pending()),
675    };
676
677    loop {
678        tokio::select! {
679            _ = ttl_fut.next() => {
680                let heartbeat_error_message = heartbeat_error_message.clone();
681                pubsub.kill_dead_topics(heartbeat_error_message).await;
682            }
683            message = receiver.next() => {
684                match message {
685                    Some(message) => {
686                        match message {
687                            Notification::Unsubscribe { topic } => pubsub.unsubscribe(topic),
688                            Notification::ForceDelete { topic } => pubsub.force_delete(topic),
689                            Notification::CreateOrSubscribe { topic,  msg_sender, response_sender, heartbeat_enabled, operation_name } => {
690                                pubsub.subscribe_or_create(topic, msg_sender, response_sender, heartbeat_enabled, operation_name);
691                            }
692                            Notification::Subscribe {
693                                topic,
694                                response_sender,
695                            } => {
696                                pubsub.subscribe(topic, response_sender);
697                            }
698                            Notification::SubscribeIfExist {
699                                topic,
700                                response_sender,
701                            } => {
702                                if pubsub.is_used(&topic) {
703                                    pubsub.subscribe(topic, response_sender);
704                                } else {
705                                    pubsub.force_delete(topic);
706                                    let _ = response_sender.send(None);
707                                }
708                            }
709                            Notification::InvalidIds {
710                                topics,
711                                response_sender,
712                            } => {
713                                let invalid_topics = pubsub.invalid_topics(topics);
714                                let _ = response_sender.send(invalid_topics);
715                            }
716                            Notification::UpdateHeartbeat {
717                                mut new_ttl
718                            } => {
719                                // We accept to miss max 3 heartbeats before cutting the connection
720                                new_ttl = new_ttl.map(|ttl| ttl * 3);
721                                if ttl != new_ttl {
722                                    ttl = new_ttl;
723                                    pubsub.ttl = new_ttl;
724                                    match new_ttl {
725                                        Some(new_ttl) => {
726                                            ttl_fut = Box::new(IntervalStream::new(tokio::time::interval(new_ttl)));
727                                        },
728                                        None => {
729                                            ttl_fut = Box::new(tokio_stream::pending());
730                                        }
731                                    }
732                                }
733
734                            }
735                            Notification::Exist {
736                                topic,
737                                response_sender,
738                            } => {
739                                let exist = pubsub.exist(&topic);
740                                let _ = response_sender.send(exist);
741                                if exist {
742                                    pubsub.touch(&topic);
743                                }
744                            }
745                            #[cfg(test)]
746                            Notification::TryDelete { topic } => pubsub.try_delete(topic),
747                            #[cfg(test)]
748                            Notification::Broadcast { data } => {
749                                pubsub.broadcast(data).await;
750                            }
751                            #[cfg(test)]
752                            Notification::Debug { response_sender } => {
753                                let _ = response_sender.send(pubsub.subscriptions.len());
754                            }
755                        }
756                    },
757                    None => break,
758                }
759            }
760        }
761    }
762}
763
/// State kept by the pub/sub task for one topic.
#[derive(Debug)]
struct Subscription<V> {
    // Sender side of the topic's broadcast channel
    msg_sender: broadcast::Sender<Option<V>>,
    // Sender side of the closing-signal channel handed to subscribers
    closing_signal: broadcast::Sender<()>,
    // Flag given when the topic was created
    heartbeat_enabled: bool,
    // Set at creation and refreshed by `touch()`
    updated_at: Instant,
    // Guard returned by the `apollo.router.opened.subscriptions` counter macro
    _metric_guard: UpDownCounterGuard<i64>,
}
772
773impl<V> Subscription<V> {
774    fn new(
775        msg_sender: broadcast::Sender<Option<V>>,
776        closing_signal: broadcast::Sender<()>,
777        heartbeat_enabled: bool,
778        operation_name: Option<String>,
779    ) -> Self {
780        let metric_guard = i64_up_down_counter!(
781            "apollo.router.opened.subscriptions",
782            "Number of opened subscriptions",
783            1,
784            graphql.operation.name = operation_name.unwrap_or_default()
785        );
786
787        Self {
788            msg_sender,
789            closing_signal,
790            heartbeat_enabled,
791            updated_at: Instant::now(),
792            _metric_guard: metric_guard,
793        }
794    }
795    // Update the updated_at value
796    fn touch(&mut self) {
797        self.updated_at = Instant::now();
798    }
799
800    fn closing_signal(&self) -> broadcast::Receiver<()> {
801        self.closing_signal.subscribe()
802    }
803}
804
/// Topic registry owned by the pub/sub task.
struct PubSub<K, V>
where
    K: Hash + Eq,
{
    // Registered topics and their state
    subscriptions: HashMap<K, Subscription<V>>,
    // Current TTL; the task overwrites it on `UpdateHeartbeat`
    ttl: Option<Duration>,
}
812
813impl<K, V> Default for PubSub<K, V>
814where
815    K: Hash + Eq,
816{
817    fn default() -> Self {
818        Self {
819            // subscribers: HashMap::new(),
820            subscriptions: HashMap::new(),
821            ttl: None,
822        }
823    }
824}
825
826impl<K, V> PubSub<K, V>
827where
828    K: Hash + Eq + Clone,
829    V: Clone + 'static,
830{
831    fn new(ttl: Option<Duration>) -> Self {
832        Self {
833            subscriptions: HashMap::new(),
834            ttl,
835        }
836    }
837
838    fn create_topic(
839        &mut self,
840        topic: K,
841        sender: broadcast::Sender<Option<V>>,
842        heartbeat_enabled: bool,
843        operation_name: Option<String>,
844    ) -> broadcast::Receiver<()> {
845        let (closing_signal_tx, closing_signal_rx) = broadcast::channel(1);
846        self.subscriptions.insert(
847            topic,
848            Subscription::new(
849                sender,
850                closing_signal_tx,
851                heartbeat_enabled,
852                operation_name.clone(),
853            ),
854        );
855
856        closing_signal_rx
857    }
858
859    fn subscribe(&mut self, topic: K, sender: ResponseSender<V>) {
860        match self.subscriptions.get_mut(&topic) {
861            Some(subscription) => {
862                let _ = sender.send(Some((
863                    subscription.msg_sender.clone(),
864                    subscription.msg_sender.subscribe(),
865                )));
866            }
867            None => {
868                let _ = sender.send(None);
869            }
870        }
871    }
872
873    fn subscribe_or_create(
874        &mut self,
875        topic: K,
876        msg_sender: broadcast::Sender<Option<V>>,
877        sender: ResponseSenderWithCreated<V>,
878        heartbeat_enabled: bool,
879        operation_name: Option<String>,
880    ) {
881        match self.subscriptions.get(&topic) {
882            Some(subscription) => {
883                let _ = sender.send(CreatedTopicPayload {
884                    msg_sender: subscription.msg_sender.clone(),
885                    msg_receiver: subscription.msg_sender.subscribe(),
886                    closing_signal: subscription.closing_signal(),
887                    created: false,
888                });
889            }
890            None => {
891                let closing_signal =
892                    self.create_topic(topic, msg_sender.clone(), heartbeat_enabled, operation_name);
893
894                let _ = sender.send(CreatedTopicPayload {
895                    msg_sender: msg_sender.clone(),
896                    msg_receiver: msg_sender.subscribe(),
897                    closing_signal,
898                    created: true,
899                });
900            }
901        }
902    }
903
904    fn unsubscribe(&mut self, topic: K) {
905        let mut topic_to_delete = false;
906        match self.subscriptions.get(&topic) {
907            Some(subscription) => {
908                topic_to_delete = subscription.msg_sender.receiver_count() == 0;
909            }
910            None => tracing::trace!("Cannot find the subscription to unsubscribe"),
911        }
912        if topic_to_delete {
913            tracing::trace!("deleting subscription from unsubscribe");
914            self.force_delete(topic);
915        };
916    }
917
918    /// Check if the topic is used by anyone else than the current handle
919    fn is_used(&self, topic: &K) -> bool {
920        self.subscriptions
921            .get(topic)
922            .map(|s| s.msg_sender.receiver_count() > 0)
923            .unwrap_or_default()
924    }
925
926    /// Update the heartbeat
927    fn touch(&mut self, topic: &K) {
928        if let Some(sub) = self.subscriptions.get_mut(topic) {
929            sub.touch();
930        }
931    }
932
933    /// Check if the topic exists
934    fn exist(&self, topic: &K) -> bool {
935        self.subscriptions.contains_key(topic)
936    }
937
938    /// Given a list of topics, returns the list of valid and invalid topics
939    /// Heartbeat the given valid topics
940    fn invalid_topics(&mut self, topics: Vec<K>) -> (Vec<K>, Vec<K>) {
941        topics.into_iter().fold(
942            (Vec::new(), Vec::new()),
943            |(mut valid_ids, mut invalid_ids), e| {
944                match self.subscriptions.get_mut(&e) {
945                    Some(sub) => {
946                        sub.touch();
947                        valid_ids.push(e);
948                    }
949                    None => {
950                        invalid_ids.push(e);
951                    }
952                }
953
954                (valid_ids, invalid_ids)
955            },
956        )
957    }
958
959    /// clean all topics which didn't heartbeat
960    async fn kill_dead_topics(&mut self, heartbeat_error_message: Option<V>) {
961        if let Some(ttl) = self.ttl {
962            let drained = self.subscriptions.drain();
963            let (remaining_subs, closed_subs) = drained.into_iter().fold(
964                (HashMap::new(), HashMap::new()),
965                |(mut acc, mut acc_error), (topic, sub)| {
966                    if (!sub.heartbeat_enabled || sub.updated_at.elapsed() <= ttl)
967                        && sub.msg_sender.receiver_count() > 0
968                    {
969                        acc.insert(topic, sub);
970                    } else {
971                        acc_error.insert(topic, sub);
972                    }
973
974                    (acc, acc_error)
975                },
976            );
977            self.subscriptions = remaining_subs;
978
979            // Send error message to all killed connections
980            for (_, subscription) in closed_subs {
981                tracing::trace!("deleting subscription from kill_dead_topics");
982                self._force_delete(subscription, heartbeat_error_message.as_ref());
983            }
984        }
985    }
986
/// Test-only helper: deletes `topic` unless it is registered with more than
/// one receiver on its broadcast sender.
///
/// An unknown topic is forwarded to `force_delete` as well, which does nothing
/// for it besides tracing.
#[cfg(test)]
fn try_delete(&mut self, topic: K) {
    let has_several_receivers = self
        .subscriptions
        .get(&topic)
        .is_some_and(|sub| sub.msg_sender.receiver_count() > 1);

    if !has_several_receivers {
        self.force_delete(topic);
    }
}
997
998    fn force_delete(&mut self, topic: K) {
999        tracing::trace!("deleting subscription from force_delete");
1000        let sub = self.subscriptions.remove(&topic);
1001        if let Some(sub) = sub {
1002            self._force_delete(sub, None);
1003        }
1004    }
1005
1006    fn _force_delete(&mut self, sub: Subscription<V>, error_message: Option<&V>) {
1007        tracing::trace!("deleting subscription from _force_delete");
1008        if let Some(error_message) = error_message {
1009            let _ = sub.msg_sender.send(error_message.clone().into());
1010        }
1011        let _ = sub.msg_sender.send(None);
1012        let _ = sub.closing_signal.send(());
1013    }
1014
/// Test-only helper: sends a clone of `value` to every registered topic, then
/// prunes the registry.
///
/// A topic is removed when the send to it failed or when its broadcast sender
/// has no receiver left. Always returns `Some(())`.
#[cfg(test)]
async fn broadcast(&mut self, value: V) -> Option<()>
where
    V: Clone,
{
    // Ids of the topics whose send failed.
    let failed_topics: Vec<K> = self
        .subscriptions
        .iter()
        .filter_map(|(topic, sub)| {
            sub.msg_sender
                .send(Some(value.clone()))
                .is_err()
                .then(|| topic.clone())
        })
        .collect();

    // clean closed sender
    self.subscriptions.retain(|topic, sub| {
        sub.msg_sender.receiver_count() > 0 && !failed_topics.contains(topic)
    });

    Some(())
}
1038}
1039
1040pub(crate) struct RouterBroadcasts {
1041    configuration: (
1042        broadcast::Sender<Weak<Configuration>>,
1043        broadcast::Receiver<Weak<Configuration>>,
1044    ),
1045    schema: (
1046        broadcast::Sender<Arc<Schema>>,
1047        broadcast::Receiver<Arc<Schema>>,
1048    ),
1049}
1050
1051impl RouterBroadcasts {
1052    pub(crate) fn new() -> Self {
1053        Self {
1054            // Set to 2 to avoid potential deadlock when triggering a config/schema change mutiple times in a row
1055            configuration: broadcast::channel(2),
1056            schema: broadcast::channel(2),
1057        }
1058    }
1059
1060    pub(crate) fn subscribe_configuration(
1061        &self,
1062    ) -> impl Stream<Item = Weak<Configuration>> + use<> {
1063        BroadcastStream::new(self.configuration.0.subscribe())
1064            .filter_map(|cfg| futures::future::ready(cfg.ok()))
1065    }
1066
1067    pub(crate) fn subscribe_schema(&self) -> impl Stream<Item = Arc<Schema>> + use<> {
1068        BroadcastStream::new(self.schema.0.subscribe())
1069            .filter_map(|schema| futures::future::ready(schema.ok()))
1070    }
1071}
1072
#[cfg(test)]
mod tests {
    //! Tests for the pub/sub facility: subscription lifecycle, explicit deletion,
    //! the opened-subscriptions metric and TTL-based cleanup.

    use futures::FutureExt;
    use tokio_stream::StreamExt;
    use uuid::Uuid;

    use super::*;
    use crate::metrics::FutureMetricsExt;

    /// A message sent on a topic reaches every stream subscribed to it, and no
    /// subscription is left once every handle has been dropped.
    #[tokio::test]
    async fn subscribe() {
        let mut notify = Notify::builder().build();
        let topic_1 = Uuid::new_v4();
        let topic_2 = Uuid::new_v4();

        let (handle1, created, mut subscription_closing_signal_1) = notify
            .create_or_subscribe(topic_1, false, None)
            .await
            .unwrap();
        assert!(created);
        let (handle2, created, mut subscription_closing_signal_2) = notify
            .create_or_subscribe(topic_2, false, None)
            .await
            .unwrap();
        assert!(created);

        let handle_1_bis = notify.subscribe(topic_1).await.unwrap();
        let handle_1_other = notify.subscribe(topic_1).await.unwrap();
        let mut cloned_notify = notify.clone();

        // Publish one message on topic_1 through a sink obtained from a cloned `Notify`.
        let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
        handle
            .send_sync(serde_json_bytes::json!({"test": "ok"}))
            .unwrap();
        drop(handle);
        drop(handle1);
        // Both remaining topic_1 subscribers must observe the message.
        let mut handle_1_bis = handle_1_bis.into_stream();
        let new_msg = handle_1_bis.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
        let mut handle_1_other = handle_1_other.into_stream();
        let new_msg = handle_1_other.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));

        assert!(notify.exist(topic_1).await.unwrap());
        assert!(notify.exist(topic_2).await.unwrap());

        drop(handle2);
        drop(handle_1_bis);
        drop(handle_1_other);

        // Every handle is gone: no subscription is expected to remain.
        let subscriptions_nb = notify.debug().await.unwrap();
        assert_eq!(subscriptions_nb, 0);

        // Both closing signals must already have been emitted.
        subscription_closing_signal_1.try_recv().unwrap();
        subscription_closing_signal_2.try_recv().unwrap();
    }

    /// Topics can be removed with `try_delete` / `force_delete`; a stream on a
    /// deleted topic still yields the pending message and then terminates.
    #[tokio::test]
    async fn it_subscribe_and_delete() {
        let mut notify = Notify::builder().build();
        let topic_1 = Uuid::new_v4();
        let topic_2 = Uuid::new_v4();

        let (handle1, created, mut subscription_closing_signal_1) = notify
            .create_or_subscribe(topic_1, true, None)
            .await
            .unwrap();
        assert!(created);
        let (_handle2, created, mut subscription_closing_signal_2) = notify
            .create_or_subscribe(topic_2, true, None)
            .await
            .unwrap();
        assert!(created);

        let handle_1_bis = notify.subscribe(topic_1).await.unwrap();
        let handle_1_other = notify.subscribe(topic_1).await.unwrap();
        let mut cloned_notify = notify.clone();
        let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
        handle
            .send_sync(serde_json_bytes::json!({"test": "ok"}))
            .unwrap();
        drop(handle);
        assert!(notify.exist(topic_1).await.unwrap());
        drop(handle_1_bis);
        drop(handle_1_other);

        notify.try_delete(topic_1).unwrap();

        // Only topic_2 is expected to be left.
        let subscriptions_nb = notify.debug().await.unwrap();
        assert_eq!(subscriptions_nb, 1);

        assert!(!notify.exist(topic_1).await.unwrap());

        // Deleting an already removed topic must not fail.
        notify.force_delete(topic_1).await.unwrap();

        // The message sent before the deletion is still delivered, then the stream ends.
        let mut handle1 = handle1.into_stream();
        let new_msg = handle1.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
        assert!(handle1.next().await.is_none());
        assert!(notify.exist(topic_2).await.unwrap());
        notify.try_delete(topic_2).unwrap();

        let subscriptions_nb = notify.debug().await.unwrap();
        assert_eq!(subscriptions_nb, 0);
        drop(handle1);
        subscription_closing_signal_1.try_recv().unwrap();
        subscription_closing_signal_2.try_recv().unwrap();
    }

    /// Same scenario as `it_subscribe_and_delete`, additionally checking the
    /// `apollo.router.opened.subscriptions` up-down counter per operation name.
    #[tokio::test]
    async fn it_subscribe_and_delete_metrics() {
        async {
            let mut notify = Notify::builder().build();
            let topic_1 = Uuid::new_v4();
            let topic_2 = Uuid::new_v4();

            let (handle1, created, mut subscription_closing_signal_1) = notify
                .create_or_subscribe(topic_1, true, Some("TestSubscription".to_string()))
                .await
                .unwrap();
            assert!(created);
            let (_handle2, created, mut subscription_closing_signal_2) = notify
                .create_or_subscribe(topic_2, true, Some("TestSubscriptionBis".to_string()))
                .await
                .unwrap();
            assert!(created);
            // One opened subscription is counted for each operation name.
            assert_up_down_counter!(
                "apollo.router.opened.subscriptions",
                1i64,
                "graphql.operation.name" = "TestSubscription"
            );
            assert_up_down_counter!(
                "apollo.router.opened.subscriptions",
                1i64,
                "graphql.operation.name" = "TestSubscriptionBis"
            );

            let handle_1_bis = notify.subscribe(topic_1).await.unwrap();
            let handle_1_other = notify.subscribe(topic_1).await.unwrap();
            let mut cloned_notify = notify.clone();
            let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
            handle
                .send_sync(serde_json_bytes::json!({"test": "ok"}))
                .unwrap();
            drop(handle);
            assert!(notify.exist(topic_1).await.unwrap());
            drop(handle_1_bis);
            drop(handle_1_other);

            notify.try_delete(topic_1).unwrap();
            // Counters as expected right after the `try_delete` call.
            assert_up_down_counter!(
                "apollo.router.opened.subscriptions",
                1i64,
                "graphql.operation.name" = "TestSubscription"
            );
            assert_up_down_counter!(
                "apollo.router.opened.subscriptions",
                1i64,
                "graphql.operation.name" = "TestSubscriptionBis"
            );

            let subscriptions_nb = notify.debug().await.unwrap();
            assert_eq!(subscriptions_nb, 1);

            assert!(!notify.exist(topic_1).await.unwrap());

            notify.force_delete(topic_1).await.unwrap();
            // At this point the counter for `TestSubscription` is expected to be back to 0.
            assert_up_down_counter!(
                "apollo.router.opened.subscriptions",
                0i64,
                "graphql.operation.name" = "TestSubscription"
            );
            assert_up_down_counter!(
                "apollo.router.opened.subscriptions",
                1i64,
                "graphql.operation.name" = "TestSubscriptionBis"
            );

            let mut handle1 = handle1.into_stream();
            let new_msg = handle1.next().await.unwrap();
            assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
            assert!(handle1.next().await.is_none());
            assert!(notify.exist(topic_2).await.unwrap());
            notify.try_delete(topic_2).unwrap();

            let subscriptions_nb = notify.debug().await.unwrap();
            assert_eq!(subscriptions_nb, 0);
            // Both counters are expected to be back to 0 once every topic is gone.
            assert_up_down_counter!(
                "apollo.router.opened.subscriptions",
                0i64,
                "graphql.operation.name" = "TestSubscription"
            );
            assert_up_down_counter!(
                "apollo.router.opened.subscriptions",
                0i64,
                "graphql.operation.name" = "TestSubscriptionBis"
            );
            subscription_closing_signal_1.try_recv().unwrap();
            subscription_closing_signal_2.try_recv().unwrap();
        }
        .with_metrics()
        .await;
    }

    /// Topics that stop heartbeating are removed once the TTL has elapsed, and
    /// their subscribers receive the configured heartbeat error message before
    /// the stream ends.
    #[tokio::test]
    async fn it_test_ttl() {
        let mut notify = Notify::builder()
            .ttl(Duration::from_millis(300))
            .heartbeat_error_message(serde_json_bytes::json!({"error": "connection_closed"}))
            .build();
        let topic_1 = Uuid::new_v4();
        let topic_2 = Uuid::new_v4();

        let (handle1, created, mut subscription_closing_signal_1) = notify
            .create_or_subscribe(topic_1, true, None)
            .await
            .unwrap();
        assert!(created);
        let (_handle2, created, mut subscription_closing_signal_2) = notify
            .create_or_subscribe(topic_2, true, None)
            .await
            .unwrap();
        assert!(created);

        let handle_1_bis = notify.subscribe(topic_1).await.unwrap();
        let handle_1_other = notify.subscribe(topic_1).await.unwrap();
        let mut cloned_notify = notify.clone();
        // Publish from a separate task; the streams below wait for the message.
        tokio::spawn(async move {
            let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
            handle
                .send_sync(serde_json_bytes::json!({"test": "ok"}))
                .unwrap();
        });
        drop(handle1);

        let mut handle_1_bis = handle_1_bis.into_stream();
        let new_msg = handle_1_bis.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
        let mut handle_1_other = handle_1_other.into_stream();
        let new_msg = handle_1_other.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));

        // Shorten the TTL at runtime.
        notify
            .set_ttl(Duration::from_millis(70).into())
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_millis(150)).await;
        let mut cloned_notify = notify.clone();
        tokio::spawn(async move {
            let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
            handle
                .send_sync(serde_json_bytes::json!({"test": "ok"}))
                .unwrap();
        });
        let new_msg = handle_1_bis.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
        tokio::time::sleep(Duration::from_millis(150)).await;

        // The heartbeat error message must already be queued: `now_or_never` does not wait.
        let res = handle_1_bis.next().now_or_never().unwrap();
        assert_eq!(
            res,
            Some(serde_json_bytes::json!({"error": "connection_closed"}))
        );

        // After the error message the stream is terminated.
        assert!(handle_1_bis.next().await.is_none());

        assert!(!notify.exist(topic_1).await.unwrap());
        assert!(!notify.exist(topic_2).await.unwrap());
        subscription_closing_signal_1.try_recv().unwrap();
        subscription_closing_signal_2.try_recv().unwrap();

        let subscriptions_nb = notify.debug().await.unwrap();
        assert_eq!(subscriptions_nb, 0);
    }
}