Skip to main content

apollo_router/
notification.rs

//! Internal pub/sub facility for subscriptions
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::hash::Hash;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::sync::Weak;
8use std::task::Context;
9use std::task::Poll;
10use std::time::Duration;
11use std::time::Instant;
12
13use futures::Sink;
14use futures::Stream;
15use futures::StreamExt;
16use pin_project_lite::pin_project;
17use thiserror::Error;
18use tokio::sync::broadcast;
19use tokio::sync::mpsc;
20use tokio::sync::mpsc::error::SendError;
21use tokio::sync::mpsc::error::TrySendError;
22use tokio::sync::oneshot;
23use tokio::sync::oneshot::error::RecvError;
24use tokio_stream::wrappers::BroadcastStream;
25use tokio_stream::wrappers::IntervalStream;
26use tokio_stream::wrappers::ReceiverStream;
27use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
28
29use crate::Configuration;
30use crate::graphql;
31use crate::metrics::FutureMetricsExt;
32use crate::spec::Schema;
33
34static NOTIFY_CHANNEL_SIZE: usize = 1024;
35static DEFAULT_MSG_CHANNEL_SIZE: usize = 128;
36
37#[derive(Error, Debug)]
38pub(crate) enum NotifyError<K, V> {
39    #[error("cannot receive data from pubsub")]
40    RecvError(#[from] RecvError),
41    #[error("cannot send data to pubsub")]
42    SendError(#[from] SendError<V>),
43    #[error("cannot send data to pubsub")]
44    NotificationSendError(#[from] SendError<Notification<K, V>>),
45    #[error("cannot send data to pubsub")]
46    NotificationTrySendError(#[from] TrySendError<Notification<K, V>>),
47    #[error("cannot send data to response stream")]
48    BroadcastSendError(#[from] broadcast::error::SendError<V>),
49    #[error("this topic doesn't exist")]
50    UnknownTopic,
51}
52
53type ResponseSender<V> =
54    oneshot::Sender<Option<(broadcast::Sender<Option<V>>, broadcast::Receiver<Option<V>>)>>;
55
56pub(crate) struct CreatedTopicPayload<V> {
57    msg_sender: broadcast::Sender<Option<V>>,
58    msg_receiver: broadcast::Receiver<Option<V>>,
59    closing_signal: broadcast::Receiver<()>,
60    created: bool,
61}
62
63type ResponseSenderWithCreated<V> = oneshot::Sender<CreatedTopicPayload<V>>;
64
65pub(crate) enum Notification<K, V> {
66    CreateOrSubscribe {
67        topic: K,
68        // Sender connected to the original source stream
69        msg_sender: broadcast::Sender<Option<V>>,
70        // To know if it has been created or re-used
71        response_sender: ResponseSenderWithCreated<V>,
72        heartbeat_enabled: bool,
73        // Useful for the metric we create
74        operation_name: Option<String>,
75    },
76    Subscribe {
77        topic: K,
78        // Oneshot channel to fetch the receiver
79        response_sender: ResponseSender<V>,
80    },
81    SubscribeIfExist {
82        topic: K,
83        // Oneshot channel to fetch the receiver
84        response_sender: ResponseSender<V>,
85    },
86    Unsubscribe {
87        topic: K,
88    },
89    ForceDelete {
90        topic: K,
91    },
92    Exist {
93        topic: K,
94        response_sender: oneshot::Sender<bool>,
95    },
96    InvalidIds {
97        topics: Vec<K>,
98        response_sender: oneshot::Sender<(Vec<K>, Vec<K>)>,
99    },
100    UpdateHeartbeat {
101        new_ttl: Option<Duration>,
102    },
103    #[cfg(test)]
104    TryDelete {
105        topic: K,
106    },
107    #[cfg(test)]
108    Broadcast {
109        data: V,
110    },
111    #[cfg(test)]
112    Debug {
113        // Returns the number of subscriptions and subscribers
114        response_sender: oneshot::Sender<usize>,
115    },
116}
117
118impl<K, V> Debug for Notification<K, V> {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        match self {
121            Self::CreateOrSubscribe { .. } => f.debug_struct("CreateOrSubscribe").finish(),
122            Self::Subscribe { .. } => f.debug_struct("Subscribe").finish(),
123            Self::SubscribeIfExist { .. } => f.debug_struct("SubscribeIfExist").finish(),
124            Self::Unsubscribe { .. } => f.debug_struct("Unsubscribe").finish(),
125            Self::ForceDelete { .. } => f.debug_struct("ForceDelete").finish(),
126            Self::Exist { .. } => f.debug_struct("Exist").finish(),
127            Self::InvalidIds { .. } => f.debug_struct("InvalidIds").finish(),
128            Self::UpdateHeartbeat { .. } => f.debug_struct("UpdateHeartbeat").finish(),
129            #[cfg(test)]
130            Self::TryDelete { .. } => f.debug_struct("TryDelete").finish(),
131            #[cfg(test)]
132            Self::Broadcast { .. } => f.debug_struct("Broadcast").finish(),
133            #[cfg(test)]
134            Self::Debug { .. } => f.debug_struct("Debug").finish(),
135        }
136    }
137}
138
139/// In memory pub/sub implementation
140#[derive(Clone)]
141pub struct Notify<K, V> {
142    sender: mpsc::Sender<Notification<K, V>>,
143    /// Size (number of events) of the channel to receive message
144    pub(crate) queue_size: Option<usize>,
145    router_broadcasts: Arc<RouterBroadcasts>,
146}
147
148#[buildstructor::buildstructor]
149impl<K, V> Notify<K, V>
150where
151    K: Send + Hash + Eq + Clone + 'static,
152    V: Send + Sync + Clone + 'static,
153{
154    #[builder]
155    pub(crate) fn new(
156        ttl: Option<Duration>,
157        heartbeat_error_message: Option<V>,
158        queue_size: Option<usize>,
159    ) -> Notify<K, V> {
160        let (sender, receiver) = mpsc::channel(NOTIFY_CHANNEL_SIZE);
161        let receiver_stream: ReceiverStream<Notification<K, V>> = ReceiverStream::new(receiver);
162        tokio::task::spawn(
163            task(receiver_stream, ttl, heartbeat_error_message).with_current_meter_provider(),
164        );
165        Notify {
166            sender,
167            queue_size,
168            router_broadcasts: Arc::new(RouterBroadcasts::new()),
169        }
170    }
171
172    #[doc(hidden)]
173    /// NOOP notifier for tests
174    pub fn for_tests() -> Self {
175        let (sender, _receiver) = mpsc::channel(NOTIFY_CHANNEL_SIZE);
176        Notify {
177            sender,
178            queue_size: None,
179            router_broadcasts: Arc::new(RouterBroadcasts::new()),
180        }
181    }
182}
183
184impl<K, V> Notify<K, V> {
185    /// Broadcast a new configuration
186    pub(crate) fn broadcast_configuration(&self, configuration: Weak<Configuration>) {
187        self.router_broadcasts.configuration.0.send(configuration).expect("cannot send the configuration update to the static channel. Should not happen because the receiver will always live in this struct; qed");
188    }
189    /// Receive the new configuration everytime we have a new router configuration
190    pub(crate) fn subscribe_configuration(&self) -> impl Stream<Item = Weak<Configuration>> {
191        self.router_broadcasts.subscribe_configuration()
192    }
193    /// Receive the new schema everytime we have a new schema
194    pub(crate) fn broadcast_schema(&self, schema: Arc<Schema>) {
195        self.router_broadcasts.schema.0.send(schema).expect("cannot send the schema update to the static channel. Should not happen because the receiver will always live in this struct; qed");
196    }
197    /// Receive the new schema everytime we have a new schema
198    pub(crate) fn subscribe_schema(&self) -> impl Stream<Item = Arc<Schema>> {
199        self.router_broadcasts.subscribe_schema()
200    }
201}
202
203impl<K, V> Notify<K, V>
204where
205    K: Send + Hash + Eq + Clone + 'static,
206    V: Send + Clone + 'static,
207{
208    pub(crate) async fn set_ttl(&self, new_ttl: Option<Duration>) -> Result<(), NotifyError<K, V>> {
209        self.sender
210            .send(Notification::UpdateHeartbeat { new_ttl })
211            .await?;
212
213        Ok(())
214    }
215
216    /// Creates or subscribes to a topic, returning a handle and subscription state.
217    ///
218    /// The `Ok()` branch of the `Result` is a tuple, where:
219    ///     - .0: a `Handle` on the subscription event listener,
220    ///     - .1: a boolean, where
221    ///              - `true`: call to this fn `created` this subscription, and
222    ///              - `false`: call to this fn was for a deduplicated subscription
223    ///                         i.e. subscription already exists,
224    ///     - .2: a closing signal in a form of `broadcast::Receiver` that gets
225    ///            triggered once the subscription is closed.
226    ///
227    /// # Closing Signal Usage
228    ///
229    /// The closing signal's usage depends on how subscriptions are managed:
230    ///
231    /// ## Callback Mode (HTTP-based subscriptions)
232    /// - The closing signal is typically **unused** as there are no long-running
233    ///   forwarding tasks to clean up
234    /// - Subscriptions are managed via HTTP callbacks to a public URL
235    /// - Subscription lifecycle is controlled through HTTP responses (404 closes the subscription)
236    /// - Always called with `heartbeat_enabled = true` to enable TTL-based timeout checking
237    ///
238    /// ## Passthrough Mode (WebSocket-based subscriptions)  
239    /// - The closing signal **must be monitored** by the forwarding task using `tokio::select!`
240    /// - Maintains persistent WebSocket connections to subgraphs
241    /// - Needed for proper cleanup when subscriptions are terminated, especially important
242    ///   for deduplication (multiple clients may share one subgraph connection)
243    /// - Always called with `heartbeat_enabled = false` as WebSockets have their own
244    ///   connection management
245    ///
246    /// # Parameters
247    /// - `topic`: The subscription topic identifier
248    /// - `heartbeat_enabled`: Controls TTL-based timeout checking at the notification layer:
249    ///   - `true`: Enables TTL checking. For callback mode, subscriptions will timeout if
250    ///     no heartbeat is received within the TTL period. The actual heartbeat interval
251    ///     is configured separately and sent to subgraphs in the subscription extension.
252    ///     When subgraphs send heartbeat messages, they're processed via `invalid_ids()`
253    ///     which calls `touch()` to update the subscription's `updated_at` timestamp.
254    ///   - `false`: Disables TTL checking (used by passthrough/WebSocket mode)
255    /// - `operation_name`: Optional GraphQL operation name for metrics
256    ///
257    /// # Heartbeat Processing for Callback Mode
258    ///
259    /// When callback mode is configured with a heartbeat interval:
260    /// 1. The interval is converted to milliseconds and sent to the subgraph as
261    ///    `heartbeat_interval_ms` in the subscription extension
262    /// 2. Subgraphs send periodic heartbeat callbacks with subscription IDs
263    /// 3. The heartbeat handler validates IDs and calls `notify.invalid_ids()`
264    /// 4. This updates each valid subscription's timestamp via `touch()`
265    /// 5. The TTL checker uses these timestamps to determine if subscriptions are alive
266    ///    and closes those that haven't been touched within the TTL period
267    pub(crate) async fn create_or_subscribe(
268        &mut self,
269        topic: K,
270        heartbeat_enabled: bool,
271        operation_name: Option<String>,
272    ) -> Result<(Handle<K, V>, bool, broadcast::Receiver<()>), NotifyError<K, V>> {
273        let (sender, _receiver) =
274            broadcast::channel(self.queue_size.unwrap_or(DEFAULT_MSG_CHANNEL_SIZE));
275
276        let (tx, rx) = oneshot::channel();
277        self.sender
278            .send(Notification::CreateOrSubscribe {
279                topic: topic.clone(),
280                msg_sender: sender,
281                response_sender: tx,
282                heartbeat_enabled,
283                operation_name,
284            })
285            .await?;
286
287        let CreatedTopicPayload {
288            msg_sender,
289            msg_receiver,
290            closing_signal,
291            created,
292        } = rx.await?;
293        let handle = Handle::new(
294            topic,
295            self.sender.clone(),
296            msg_sender,
297            BroadcastStream::from(msg_receiver),
298        );
299
300        Ok((handle, created, closing_signal))
301    }
302
303    pub(crate) async fn subscribe(&mut self, topic: K) -> Result<Handle<K, V>, NotifyError<K, V>> {
304        let (sender, receiver) = oneshot::channel();
305
306        self.sender
307            .send(Notification::Subscribe {
308                topic: topic.clone(),
309                response_sender: sender,
310            })
311            .await?;
312
313        let Some((msg_sender, msg_receiver)) = receiver.await? else {
314            return Err(NotifyError::UnknownTopic);
315        };
316        let handle = Handle::new(
317            topic,
318            self.sender.clone(),
319            msg_sender,
320            BroadcastStream::from(msg_receiver),
321        );
322
323        Ok(handle)
324    }
325
326    pub(crate) async fn subscribe_if_exist(
327        &mut self,
328        topic: K,
329    ) -> Result<Option<Handle<K, V>>, NotifyError<K, V>> {
330        let (sender, receiver) = oneshot::channel();
331
332        self.sender
333            .send(Notification::SubscribeIfExist {
334                topic: topic.clone(),
335                response_sender: sender,
336            })
337            .await?;
338
339        let Some((msg_sender, msg_receiver)) = receiver.await? else {
340            return Ok(None);
341        };
342        let handle = Handle::new(
343            topic,
344            self.sender.clone(),
345            msg_sender,
346            BroadcastStream::from(msg_receiver),
347        );
348
349        Ok(handle.into())
350    }
351
352    pub(crate) async fn exist(&mut self, topic: K) -> Result<bool, NotifyError<K, V>> {
353        // Channel to check if the topic still exists or not
354        let (response_tx, response_rx) = oneshot::channel();
355
356        self.sender
357            .send(Notification::Exist {
358                topic,
359                response_sender: response_tx,
360            })
361            .await?;
362
363        let resp = response_rx.await?;
364
365        Ok(resp)
366    }
367
368    pub(crate) async fn invalid_ids(
369        &mut self,
370        topics: Vec<K>,
371    ) -> Result<(Vec<K>, Vec<K>), NotifyError<K, V>> {
372        // Channel to check if the topic still exists or not
373        let (response_tx, response_rx) = oneshot::channel();
374
375        self.sender
376            .send(Notification::InvalidIds {
377                topics,
378                response_sender: response_tx,
379            })
380            .await?;
381
382        let resp = response_rx.await?;
383
384        Ok(resp)
385    }
386
387    /// Delete the topic even if several subscribers are still listening
388    pub(crate) async fn force_delete(&mut self, topic: K) -> Result<(), NotifyError<K, V>> {
389        // if disconnected, we don't care (the task was stopped)
390        self.sender
391            .send(Notification::ForceDelete { topic })
392            .await
393            .map_err(std::convert::Into::into)
394    }
395
396    /// Delete the topic if and only if one or zero subscriber is still listening
397    /// This function is not async to allow it to be used in a Drop impl
398    #[cfg(test)]
399    pub(crate) fn try_delete(&mut self, topic: K) -> Result<(), NotifyError<K, V>> {
400        // if disconnected, we don't care (the task was stopped)
401        self.sender
402            .try_send(Notification::TryDelete { topic })
403            .map_err(|try_send_error| try_send_error.into())
404    }
405
406    #[cfg(test)]
407    pub(crate) async fn broadcast(&mut self, data: V) -> Result<(), NotifyError<K, V>> {
408        self.sender
409            .send(Notification::Broadcast { data })
410            .await
411            .map_err(std::convert::Into::into)
412    }
413
414    #[cfg(test)]
415    pub(crate) async fn debug(&mut self) -> Result<usize, NotifyError<K, V>> {
416        let (response_tx, response_rx) = oneshot::channel();
417        self.sender
418            .send(Notification::Debug {
419                response_sender: response_tx,
420            })
421            .await?;
422
423        Ok(response_rx.await.unwrap())
424    }
425}
426
427#[cfg(test)]
428impl<K, V> Default for Notify<K, V>
429where
430    K: Send + Hash + Eq + Clone + 'static,
431    V: Send + Sync + Clone + 'static,
432{
433    /// Useless notify mainly for test
434    fn default() -> Self {
435        Self::for_tests()
436    }
437}
438
439impl<K, V> Debug for Notify<K, V> {
440    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441        f.debug_struct("Notify").finish()
442    }
443}
444
445struct HandleGuard<K, V>
446where
447    K: Clone,
448{
449    topic: Arc<K>,
450    pubsub_sender: mpsc::Sender<Notification<K, V>>,
451}
452
453impl<K, V> Clone for HandleGuard<K, V>
454where
455    K: Clone,
456{
457    fn clone(&self) -> Self {
458        Self {
459            topic: self.topic.clone(),
460            pubsub_sender: self.pubsub_sender.clone(),
461        }
462    }
463}
464
465impl<K, V> Drop for HandleGuard<K, V>
466where
467    K: Clone,
468{
469    fn drop(&mut self) {
470        let err = self.pubsub_sender.try_send(Notification::Unsubscribe {
471            topic: self.topic.as_ref().clone(),
472        });
473        if let Err(err) = err {
474            tracing::trace!("cannot unsubscribe {err:?}");
475        }
476    }
477}
478
479pin_project! {
480pub struct Handle<K, V>
481where
482    K: Clone,
483{
484    handle_guard: HandleGuard<K, V>,
485    #[pin]
486    msg_sender: broadcast::Sender<Option<V>>,
487    #[pin]
488    msg_receiver: BroadcastStream<Option<V>>,
489}
490}
491
492impl<K, V> Clone for Handle<K, V>
493where
494    K: Clone,
495    V: Clone + Send + 'static,
496{
497    fn clone(&self) -> Self {
498        Self {
499            handle_guard: self.handle_guard.clone(),
500            msg_receiver: BroadcastStream::new(self.msg_sender.subscribe()),
501            msg_sender: self.msg_sender.clone(),
502        }
503    }
504}
505
506impl<K, V> Handle<K, V>
507where
508    K: Clone,
509{
510    fn new(
511        topic: K,
512        pubsub_sender: mpsc::Sender<Notification<K, V>>,
513        msg_sender: broadcast::Sender<Option<V>>,
514        msg_receiver: BroadcastStream<Option<V>>,
515    ) -> Self {
516        Self {
517            handle_guard: HandleGuard {
518                topic: Arc::new(topic),
519                pubsub_sender,
520            },
521            msg_sender,
522            msg_receiver,
523        }
524    }
525
526    pub(crate) fn into_stream(self) -> HandleStream<K, V> {
527        HandleStream {
528            handle_guard: self.handle_guard,
529            msg_receiver: self.msg_receiver,
530        }
531    }
532
533    pub(crate) fn into_sink(self) -> HandleSink<K, V> {
534        HandleSink {
535            handle_guard: self.handle_guard,
536            msg_sender: self.msg_sender,
537        }
538    }
539
540    /// Return a sink and a stream
541    pub fn split(self) -> (HandleSink<K, V>, HandleStream<K, V>) {
542        (
543            HandleSink {
544                handle_guard: self.handle_guard.clone(),
545                msg_sender: self.msg_sender,
546            },
547            HandleStream {
548                handle_guard: self.handle_guard,
549                msg_receiver: self.msg_receiver,
550            },
551        )
552    }
553}
554
555pin_project! {
556pub struct HandleStream<K, V>
557where
558    K: Clone,
559{
560    handle_guard: HandleGuard<K, V>,
561    #[pin]
562    msg_receiver: BroadcastStream<Option<V>>,
563}
564}
565
566impl<K, V> Stream for HandleStream<K, V>
567where
568    K: Clone,
569    V: Clone + 'static + Send,
570{
571    type Item = V;
572
573    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
574        let mut this = self.as_mut().project();
575
576        match Pin::new(&mut this.msg_receiver).poll_next(cx) {
577            Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(_)))) => {
578                u64_counter!(
579                    "apollo_router_skipped_event_count",
580                    "Amount of events dropped from the internal message queue",
581                    1u64
582                );
583                self.poll_next(cx)
584            }
585            Poll::Ready(None) => Poll::Ready(None),
586            Poll::Ready(Some(Ok(Some(val)))) => Poll::Ready(Some(val)),
587            Poll::Ready(Some(Ok(None))) => Poll::Ready(None),
588            Poll::Pending => Poll::Pending,
589        }
590    }
591}
592
593pin_project! {
594pub struct HandleSink<K, V>
595where
596    K: Clone,
597{
598    handle_guard: HandleGuard<K, V>,
599    #[pin]
600    msg_sender: broadcast::Sender<Option<V>>,
601}
602}
603
604impl<K, V> HandleSink<K, V>
605where
606    K: Clone,
607    V: Clone + 'static + Send,
608{
609    /// Send data to the subscribed topic
610    pub(crate) fn send_sync(&mut self, data: V) -> Result<(), NotifyError<K, V>> {
611        self.msg_sender.send(data.into()).map_err(|err| {
612            NotifyError::BroadcastSendError(broadcast::error::SendError(err.0.unwrap()))
613        })?;
614
615        Ok(())
616    }
617}
618
619impl<K, V> Sink<V> for HandleSink<K, V>
620where
621    K: Clone,
622    V: Clone + 'static + Send,
623{
624    type Error = graphql::Error;
625
626    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
627        Poll::Ready(Ok(()))
628    }
629
630    fn start_send(self: Pin<&mut Self>, item: V) -> Result<(), Self::Error> {
631        self.msg_sender.send(Some(item)).map_err(|_err| {
632            graphql::Error::builder()
633                .message("cannot send payload through pubsub")
634                .extension_code("NOTIFICATION_HANDLE_SEND_ERROR")
635                .build()
636        })?;
637        Ok(())
638    }
639
640    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
641        Poll::Ready(Ok(()))
642    }
643
644    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
645        let topic = self.handle_guard.topic.as_ref().clone();
646        let _ = self
647            .handle_guard
648            .pubsub_sender
649            .try_send(Notification::ForceDelete { topic });
650        Poll::Ready(Ok(()))
651    }
652}
653
654impl<K, V> Handle<K, V> where K: Clone {}
655
656async fn task<K, V>(
657    mut receiver: ReceiverStream<Notification<K, V>>,
658    mut ttl: Option<Duration>,
659    heartbeat_error_message: Option<V>,
660) where
661    K: Send + Hash + Eq + Clone + 'static,
662    V: Send + Clone + 'static,
663{
664    let mut pubsub: PubSub<K, V> = PubSub::new(ttl);
665
666    let mut ttl_fut: Box<dyn Stream<Item = tokio::time::Instant> + Send + Unpin> = match ttl {
667        Some(ttl) => Box::new(IntervalStream::new(tokio::time::interval(ttl))),
668        None => Box::new(tokio_stream::pending()),
669    };
670
671    loop {
672        tokio::select! {
673            _ = ttl_fut.next() => {
674                let heartbeat_error_message = heartbeat_error_message.clone();
675                pubsub.kill_dead_topics(heartbeat_error_message).await;
676            }
677            message = receiver.next() => {
678                match message {
679                    Some(message) => {
680                        match message {
681                            Notification::Unsubscribe { topic } => pubsub.unsubscribe(topic),
682                            Notification::ForceDelete { topic } => pubsub.force_delete(topic),
683                            Notification::CreateOrSubscribe { topic,  msg_sender, response_sender, heartbeat_enabled, operation_name } => {
684                                pubsub.subscribe_or_create(topic, msg_sender, response_sender, heartbeat_enabled, operation_name);
685                            }
686                            Notification::Subscribe {
687                                topic,
688                                response_sender,
689                            } => {
690                                pubsub.subscribe(topic, response_sender);
691                            }
692                            Notification::SubscribeIfExist {
693                                topic,
694                                response_sender,
695                            } => {
696                                if pubsub.is_used(&topic) {
697                                    pubsub.subscribe(topic, response_sender);
698                                } else {
699                                    pubsub.force_delete(topic);
700                                    let _ = response_sender.send(None);
701                                }
702                            }
703                            Notification::InvalidIds {
704                                topics,
705                                response_sender,
706                            } => {
707                                let invalid_topics = pubsub.invalid_topics(topics);
708                                let _ = response_sender.send(invalid_topics);
709                            }
710                            Notification::UpdateHeartbeat {
711                                mut new_ttl
712                            } => {
713                                // We accept to miss max 3 heartbeats before cutting the connection
714                                new_ttl = new_ttl.map(|ttl| ttl * 3);
715                                if ttl != new_ttl {
716                                    ttl = new_ttl;
717                                    pubsub.ttl = new_ttl;
718                                    match new_ttl {
719                                        Some(new_ttl) => {
720                                            ttl_fut = Box::new(IntervalStream::new(tokio::time::interval(new_ttl)));
721                                        },
722                                        None => {
723                                            ttl_fut = Box::new(tokio_stream::pending());
724                                        }
725                                    }
726                                }
727
728                            }
729                            Notification::Exist {
730                                topic,
731                                response_sender,
732                            } => {
733                                let exist = pubsub.exist(&topic);
734                                let _ = response_sender.send(exist);
735                                if exist {
736                                    pubsub.touch(&topic);
737                                }
738                            }
739                            #[cfg(test)]
740                            Notification::TryDelete { topic } => pubsub.try_delete(topic),
741                            #[cfg(test)]
742                            Notification::Broadcast { data } => {
743                                pubsub.broadcast(data).await;
744                            }
745                            #[cfg(test)]
746                            Notification::Debug { response_sender } => {
747                                let _ = response_sender.send(pubsub.subscriptions.len());
748                            }
749                        }
750                    },
751                    None => break,
752                }
753            }
754        }
755    }
756}
757
758#[derive(Debug)]
759struct Subscription<V> {
760    msg_sender: broadcast::Sender<Option<V>>,
761    closing_signal: broadcast::Sender<()>,
762    heartbeat_enabled: bool,
763    updated_at: Instant,
764    operation_name: Option<String>,
765}
766
767impl<V> Subscription<V> {
768    fn new(
769        msg_sender: broadcast::Sender<Option<V>>,
770        closing_signal: broadcast::Sender<()>,
771        heartbeat_enabled: bool,
772        operation_name: Option<String>,
773    ) -> Self {
774        Self {
775            msg_sender,
776            closing_signal,
777            heartbeat_enabled,
778            updated_at: Instant::now(),
779            operation_name,
780        }
781    }
782    // Update the updated_at value
783    fn touch(&mut self) {
784        self.updated_at = Instant::now();
785    }
786
787    fn closing_signal(&self) -> broadcast::Receiver<()> {
788        self.closing_signal.subscribe()
789    }
790}
791
792struct PubSub<K, V>
793where
794    K: Hash + Eq,
795{
796    subscriptions: HashMap<K, Subscription<V>>,
797    ttl: Option<Duration>,
798}
799
800impl<K, V> Default for PubSub<K, V>
801where
802    K: Hash + Eq,
803{
804    fn default() -> Self {
805        Self {
806            // subscribers: HashMap::new(),
807            subscriptions: HashMap::new(),
808            ttl: None,
809        }
810    }
811}
812
813impl<K, V> PubSub<K, V>
814where
815    K: Hash + Eq + Clone,
816    V: Clone + 'static,
817{
818    fn new(ttl: Option<Duration>) -> Self {
819        Self {
820            subscriptions: HashMap::new(),
821            ttl,
822        }
823    }
824
825    fn create_topic(
826        &mut self,
827        topic: K,
828        sender: broadcast::Sender<Option<V>>,
829        heartbeat_enabled: bool,
830        operation_name: Option<String>,
831    ) -> broadcast::Receiver<()> {
832        let (closing_signal_tx, closing_signal_rx) = broadcast::channel(1);
833        let existed = self
834            .subscriptions
835            .insert(
836                topic,
837                Subscription::new(
838                    sender,
839                    closing_signal_tx,
840                    heartbeat_enabled,
841                    operation_name.clone(),
842                ),
843            )
844            .is_some();
845        if !existed {
846            // TODO: deprecated name, should use our new convention apollo.router. for router next
847            i64_up_down_counter!(
848                "apollo_router_opened_subscriptions",
849                "Number of opened subscriptions",
850                1,
851                graphql.operation.name = operation_name.unwrap_or_default()
852            );
853        }
854
855        closing_signal_rx
856    }
857
858    fn subscribe(&mut self, topic: K, sender: ResponseSender<V>) {
859        match self.subscriptions.get_mut(&topic) {
860            Some(subscription) => {
861                let _ = sender.send(Some((
862                    subscription.msg_sender.clone(),
863                    subscription.msg_sender.subscribe(),
864                )));
865            }
866            None => {
867                let _ = sender.send(None);
868            }
869        }
870    }
871
872    fn subscribe_or_create(
873        &mut self,
874        topic: K,
875        msg_sender: broadcast::Sender<Option<V>>,
876        sender: ResponseSenderWithCreated<V>,
877        heartbeat_enabled: bool,
878        operation_name: Option<String>,
879    ) {
880        match self.subscriptions.get(&topic) {
881            Some(subscription) => {
882                let _ = sender.send(CreatedTopicPayload {
883                    msg_sender: subscription.msg_sender.clone(),
884                    msg_receiver: subscription.msg_sender.subscribe(),
885                    closing_signal: subscription.closing_signal(),
886                    created: false,
887                });
888            }
889            None => {
890                let closing_signal =
891                    self.create_topic(topic, msg_sender.clone(), heartbeat_enabled, operation_name);
892
893                let _ = sender.send(CreatedTopicPayload {
894                    msg_sender: msg_sender.clone(),
895                    msg_receiver: msg_sender.subscribe(),
896                    closing_signal,
897                    created: true,
898                });
899            }
900        }
901    }
902
903    fn unsubscribe(&mut self, topic: K) {
904        let mut topic_to_delete = false;
905        match self.subscriptions.get(&topic) {
906            Some(subscription) => {
907                topic_to_delete = subscription.msg_sender.receiver_count() == 0;
908            }
909            None => tracing::trace!("Cannot find the subscription to unsubscribe"),
910        }
911        if topic_to_delete {
912            tracing::trace!("deleting subscription from unsubscribe");
913            self.force_delete(topic);
914        };
915    }
916
917    /// Check if the topic is used by anyone else than the current handle
918    fn is_used(&self, topic: &K) -> bool {
919        self.subscriptions
920            .get(topic)
921            .map(|s| s.msg_sender.receiver_count() > 0)
922            .unwrap_or_default()
923    }
924
925    /// Update the heartbeat
926    fn touch(&mut self, topic: &K) {
927        if let Some(sub) = self.subscriptions.get_mut(topic) {
928            sub.touch();
929        }
930    }
931
932    /// Check if the topic exists
933    fn exist(&self, topic: &K) -> bool {
934        self.subscriptions.contains_key(topic)
935    }
936
937    /// Given a list of topics, returns the list of valid and invalid topics
938    /// Heartbeat the given valid topics
939    fn invalid_topics(&mut self, topics: Vec<K>) -> (Vec<K>, Vec<K>) {
940        topics.into_iter().fold(
941            (Vec::new(), Vec::new()),
942            |(mut valid_ids, mut invalid_ids), e| {
943                match self.subscriptions.get_mut(&e) {
944                    Some(sub) => {
945                        sub.touch();
946                        valid_ids.push(e);
947                    }
948                    None => {
949                        invalid_ids.push(e);
950                    }
951                }
952
953                (valid_ids, invalid_ids)
954            },
955        )
956    }
957
958    /// clean all topics which didn't heartbeat
959    async fn kill_dead_topics(&mut self, heartbeat_error_message: Option<V>) {
960        if let Some(ttl) = self.ttl {
961            let drained = self.subscriptions.drain();
962            let (remaining_subs, closed_subs) = drained.into_iter().fold(
963                (HashMap::new(), HashMap::new()),
964                |(mut acc, mut acc_error), (topic, sub)| {
965                    if (!sub.heartbeat_enabled || sub.updated_at.elapsed() <= ttl)
966                        && sub.msg_sender.receiver_count() > 0
967                    {
968                        acc.insert(topic, sub);
969                    } else {
970                        acc_error.insert(topic, sub);
971                    }
972
973                    (acc, acc_error)
974                },
975            );
976            self.subscriptions = remaining_subs;
977
978            // Send error message to all killed connections
979            for (_, subscription) in closed_subs {
980                tracing::trace!("deleting subscription from kill_dead_topics");
981                self._force_delete(subscription, heartbeat_error_message.as_ref());
982            }
983        }
984    }
985
986    #[cfg(test)]
987    fn try_delete(&mut self, topic: K) {
988        if let Some(sub) = self.subscriptions.get(&topic) {
989            if sub.msg_sender.receiver_count() > 1 {
990                return;
991            }
992        }
993
994        self.force_delete(topic);
995    }
996
997    fn force_delete(&mut self, topic: K) {
998        tracing::trace!("deleting subscription from force_delete");
999        let sub = self.subscriptions.remove(&topic);
1000        if let Some(sub) = sub {
1001            self._force_delete(sub, None);
1002        }
1003    }
1004
1005    fn _force_delete(&mut self, sub: Subscription<V>, error_message: Option<&V>) {
1006        tracing::trace!("deleting subscription from _force_delete");
1007        i64_up_down_counter!(
1008            "apollo_router_opened_subscriptions",
1009            "Number of opened subscriptions",
1010            -1,
1011            graphql.operation.name = sub.operation_name.unwrap_or_default()
1012        );
1013        if let Some(error_message) = error_message {
1014            let _ = sub.msg_sender.send(error_message.clone().into());
1015        }
1016        let _ = sub.msg_sender.send(None);
1017        let _ = sub.closing_signal.send(());
1018    }
1019
1020    #[cfg(test)]
1021    async fn broadcast(&mut self, value: V) -> Option<()>
1022    where
1023        V: Clone,
1024    {
1025        let mut fut = vec![];
1026        for (sub_id, sub) in &self.subscriptions {
1027            let cloned_value = value.clone();
1028            let sub_id = sub_id.clone();
1029            fut.push(
1030                sub.msg_sender
1031                    .send(cloned_value.into())
1032                    .is_err()
1033                    .then_some(sub_id),
1034            );
1035        }
1036        // clean closed sender
1037        let sub_to_clean: Vec<K> = fut.into_iter().flatten().collect();
1038        self.subscriptions
1039            .retain(|k, s| s.msg_sender.receiver_count() > 0 && !sub_to_clean.contains(k));
1040
1041        Some(())
1042    }
1043}
1044
/// Broadcast channels carrying configuration and schema updates.
///
/// Each field stores the `(Sender, Receiver)` pair returned by `broadcast::channel`; new
/// subscribers obtain their own receiver from the stored sender.
/// NOTE(review): the stored receiver is presumably kept so that `Sender::send` never fails for
/// lack of receivers (tokio's broadcast `send` errors when none exist) — confirm.
pub(crate) struct RouterBroadcasts {
    // Configuration updates. Values are `Weak`, so messages buffered in this channel do not keep
    // a `Configuration` alive by themselves.
    configuration: (
        broadcast::Sender<Weak<Configuration>>,
        broadcast::Receiver<Weak<Configuration>>,
    ),
    // Schema updates, shared as `Arc<Schema>`.
    schema: (
        broadcast::Sender<Arc<Schema>>,
        broadcast::Receiver<Arc<Schema>>,
    ),
}
1055
1056impl RouterBroadcasts {
1057    pub(crate) fn new() -> Self {
1058        Self {
1059            // Set to 2 to avoid potential deadlock when triggering a config/schema change mutiple times in a row
1060            configuration: broadcast::channel(2),
1061            schema: broadcast::channel(2),
1062        }
1063    }
1064
1065    pub(crate) fn subscribe_configuration(&self) -> impl Stream<Item = Weak<Configuration>> {
1066        BroadcastStream::new(self.configuration.0.subscribe())
1067            .filter_map(|cfg| futures::future::ready(cfg.ok()))
1068    }
1069
1070    pub(crate) fn subscribe_schema(&self) -> impl Stream<Item = Arc<Schema>> {
1071        BroadcastStream::new(self.schema.0.subscribe())
1072            .filter_map(|schema| futures::future::ready(schema.ok()))
1073    }
1074}
1075
#[cfg(test)]
mod tests {
    // Unit tests for `Notify`: topic creation and subscription, message fan-out, deletion,
    // the `apollo_router_opened_subscriptions` metric, and TTL-based expiry of topics.

    use futures::FutureExt;
    use tokio_stream::StreamExt;
    use uuid::Uuid;

    use super::*;
    use crate::metrics::FutureMetricsExt;

    // A message published through one handle reaches every other subscriber of the same topic.
    // Once every handle is dropped, `debug()` reports zero subscriptions and both closing
    // signals have fired.
    #[tokio::test]
    async fn subscribe() {
        let mut notify = Notify::builder().build();
        let topic_1 = Uuid::new_v4();
        let topic_2 = Uuid::new_v4();

        let (handle1, created, mut subscription_closing_signal_1) = notify
            .create_or_subscribe(topic_1, false, None)
            .await
            .unwrap();
        assert!(created);
        let (_handle2, created, mut subscription_closing_signal_2) = notify
            .create_or_subscribe(topic_2, false, None)
            .await
            .unwrap();
        assert!(created);

        let handle_1_bis = notify.subscribe(topic_1).await.unwrap();
        let handle_1_other = notify.subscribe(topic_1).await.unwrap();
        let mut cloned_notify = notify.clone();

        // Publish once on topic_1 through a handle obtained from a cloned `Notify`.
        let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
        handle
            .send_sync(serde_json_bytes::json!({"test": "ok"}))
            .unwrap();
        drop(handle);
        drop(handle1);
        // Both remaining topic_1 subscribers receive the published message.
        let mut handle_1_bis = handle_1_bis.into_stream();
        let new_msg = handle_1_bis.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
        let mut handle_1_other = handle_1_other.into_stream();
        let new_msg = handle_1_other.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));

        assert!(notify.exist(topic_1).await.unwrap());
        assert!(notify.exist(topic_2).await.unwrap());

        drop(_handle2);
        drop(handle_1_bis);
        drop(handle_1_other);

        let subscriptions_nb = notify.debug().await.unwrap();
        assert_eq!(subscriptions_nb, 0);

        subscription_closing_signal_1.try_recv().unwrap();
        subscription_closing_signal_2.try_recv().unwrap();
    }

    // `try_delete` on topic_1 removes it while `handle1` is still held.
    // - `handle1` still yields the message published earlier, then reaches end of stream.
    // - `force_delete` on the already-removed topic still returns `Ok`.
    // - Deleting topic_2 leaves no subscription.
    #[tokio::test]
    async fn it_subscribe_and_delete() {
        let mut notify = Notify::builder().build();
        let topic_1 = Uuid::new_v4();
        let topic_2 = Uuid::new_v4();

        let (handle1, created, mut subscription_closing_signal_1) = notify
            .create_or_subscribe(topic_1, true, None)
            .await
            .unwrap();
        assert!(created);
        let (_handle2, created, mut subscription_closing_signal_2) = notify
            .create_or_subscribe(topic_2, true, None)
            .await
            .unwrap();
        assert!(created);

        let mut _handle_1_bis = notify.subscribe(topic_1).await.unwrap();
        let mut _handle_1_other = notify.subscribe(topic_1).await.unwrap();
        let mut cloned_notify = notify.clone();
        let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
        handle
            .send_sync(serde_json_bytes::json!({"test": "ok"}))
            .unwrap();
        drop(handle);
        assert!(notify.exist(topic_1).await.unwrap());
        drop(_handle_1_bis);
        drop(_handle_1_other);

        notify.try_delete(topic_1).unwrap();

        let subscriptions_nb = notify.debug().await.unwrap();
        assert_eq!(subscriptions_nb, 1);

        assert!(!notify.exist(topic_1).await.unwrap());

        notify.force_delete(topic_1).await.unwrap();

        // The message sent before deletion is still delivered, followed by end of stream.
        let mut handle1 = handle1.into_stream();
        let new_msg = handle1.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
        assert!(handle1.next().await.is_none());
        assert!(notify.exist(topic_2).await.unwrap());
        notify.try_delete(topic_2).unwrap();

        let subscriptions_nb = notify.debug().await.unwrap();
        assert_eq!(subscriptions_nb, 0);
        drop(handle1);
        subscription_closing_signal_1.try_recv().unwrap();
        subscription_closing_signal_2.try_recv().unwrap();
    }

    // Same flow as `it_subscribe_and_delete`, additionally asserting the
    // `apollo_router_opened_subscriptions` up/down counter per `graphql.operation.name`.
    #[tokio::test]
    async fn it_subscribe_and_delete_metrics() {
        async {
            let mut notify = Notify::builder().build();
            let topic_1 = Uuid::new_v4();
            let topic_2 = Uuid::new_v4();

            let (handle1, created, mut subscription_closing_signal_1) = notify
                .create_or_subscribe(topic_1, true, Some("TestSubscription".to_string()))
                .await
                .unwrap();
            assert!(created);
            let (_handle2, created, mut subscription_closing_signal_2) = notify
                .create_or_subscribe(topic_2, true, Some("TestSubscriptionBis".to_string()))
                .await
                .unwrap();
            assert!(created);
            // One opened subscription per operation name after creation.
            assert_up_down_counter!(
                "apollo_router_opened_subscriptions",
                1i64,
                "graphql.operation.name" = "TestSubscription"
            );
            assert_up_down_counter!(
                "apollo_router_opened_subscriptions",
                1i64,
                "graphql.operation.name" = "TestSubscriptionBis"
            );

            let mut _handle_1_bis = notify.subscribe(topic_1).await.unwrap();
            let mut _handle_1_other = notify.subscribe(topic_1).await.unwrap();
            let mut cloned_notify = notify.clone();
            let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
            handle
                .send_sync(serde_json_bytes::json!({"test": "ok"}))
                .unwrap();
            drop(handle);
            assert!(notify.exist(topic_1).await.unwrap());
            drop(_handle_1_bis);
            drop(_handle_1_other);

            notify.try_delete(topic_1).unwrap();
            // Counters are asserted as still 1 right after `try_delete` returns.
            assert_up_down_counter!(
                "apollo_router_opened_subscriptions",
                1i64,
                "graphql.operation.name" = "TestSubscription"
            );
            assert_up_down_counter!(
                "apollo_router_opened_subscriptions",
                1i64,
                "graphql.operation.name" = "TestSubscriptionBis"
            );

            let subscriptions_nb = notify.debug().await.unwrap();
            assert_eq!(subscriptions_nb, 1);

            assert!(!notify.exist(topic_1).await.unwrap());

            notify.force_delete(topic_1).await.unwrap();
            // topic_1 no longer counted; topic_2 still open.
            assert_up_down_counter!(
                "apollo_router_opened_subscriptions",
                0i64,
                "graphql.operation.name" = "TestSubscription"
            );
            assert_up_down_counter!(
                "apollo_router_opened_subscriptions",
                1i64,
                "graphql.operation.name" = "TestSubscriptionBis"
            );

            let mut handle1 = handle1.into_stream();
            let new_msg = handle1.next().await.unwrap();
            assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
            assert!(handle1.next().await.is_none());
            assert!(notify.exist(topic_2).await.unwrap());
            notify.try_delete(topic_2).unwrap();

            let subscriptions_nb = notify.debug().await.unwrap();
            assert_eq!(subscriptions_nb, 0);
            // Every subscription closed: both counters back to zero.
            assert_up_down_counter!(
                "apollo_router_opened_subscriptions",
                0i64,
                "graphql.operation.name" = "TestSubscription"
            );
            assert_up_down_counter!(
                "apollo_router_opened_subscriptions",
                0i64,
                "graphql.operation.name" = "TestSubscriptionBis"
            );
            subscription_closing_signal_1.try_recv().unwrap();
            subscription_closing_signal_2.try_recv().unwrap();
        }
        .with_metrics()
        .await;
    }

    // TTL expiry: after the TTL is lowered and sleeps exceed it, both topics are expected to be
    // removed. Receivers get the configured heartbeat error message, then end of stream, and
    // both closing signals fire.
    // NOTE(review): timing-based (sleeps vs. TTL); may be sensitive to a slow runtime.
    #[tokio::test]
    async fn it_test_ttl() {
        let mut notify = Notify::builder()
            .ttl(Duration::from_millis(300))
            .heartbeat_error_message(serde_json_bytes::json!({"error": "connection_closed"}))
            .build();
        let topic_1 = Uuid::new_v4();
        let topic_2 = Uuid::new_v4();

        let (handle1, created, mut subscription_closing_signal_1) = notify
            .create_or_subscribe(topic_1, true, None)
            .await
            .unwrap();
        assert!(created);
        let (_handle2, created, mut subscription_closing_signal_2) = notify
            .create_or_subscribe(topic_2, true, None)
            .await
            .unwrap();
        assert!(created);

        let handle_1_bis = notify.subscribe(topic_1).await.unwrap();
        let handle_1_other = notify.subscribe(topic_1).await.unwrap();
        let mut cloned_notify = notify.clone();
        tokio::spawn(async move {
            let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
            handle
                .send_sync(serde_json_bytes::json!({"test": "ok"}))
                .unwrap();
        });
        drop(handle1);

        let mut handle_1_bis = handle_1_bis.into_stream();
        let new_msg = handle_1_bis.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
        let mut handle_1_other = handle_1_other.into_stream();
        let new_msg = handle_1_other.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));

        // Shrink the TTL so the following sleeps exceed it.
        notify
            .set_ttl(Duration::from_millis(70).into())
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_millis(150)).await;
        let mut cloned_notify = notify.clone();
        tokio::spawn(async move {
            let mut handle = cloned_notify.subscribe(topic_1).await.unwrap().into_sink();
            handle
                .send_sync(serde_json_bytes::json!({"test": "ok"}))
                .unwrap();
        });
        let new_msg = handle_1_bis.next().await.unwrap();
        assert_eq!(new_msg, serde_json_bytes::json!({"test": "ok"}));
        tokio::time::sleep(Duration::from_millis(150)).await;

        // The expired topic delivers the heartbeat error message, then terminates.
        let res = handle_1_bis.next().now_or_never().unwrap();
        assert_eq!(
            res,
            Some(serde_json_bytes::json!({"error": "connection_closed"}))
        );

        assert!(handle_1_bis.next().await.is_none());

        assert!(!notify.exist(topic_1).await.unwrap());
        assert!(!notify.exist(topic_2).await.unwrap());
        subscription_closing_signal_1.try_recv().unwrap();
        subscription_closing_signal_2.try_recv().unwrap();

        let subscriptions_nb = notify.debug().await.unwrap();
        assert_eq!(subscriptions_nb, 0);
    }
}