apollo_router/uplink/
license_stream.rs

1// tonic does not derive `Eq` for the gRPC message types, which causes a warning from Clippy. The
2// current suggestion is to explicitly allow the lint in the module that imports the protos.
3// Read more: https://github.com/hyperium/tonic/issues/1056
4#![allow(clippy::derive_partial_eq_without_eq)]
5
6use std::collections::HashSet;
7use std::pin::Pin;
8use std::str::FromStr;
9use std::sync::Arc;
10use std::task::Context;
11use std::task::Poll;
12use std::time::Instant;
13use std::time::SystemTime;
14
15use futures::Stream;
16use futures::StreamExt;
17use futures::future::Ready;
18use futures::stream::FilterMap;
19use futures::stream::Fuse;
20use futures::stream::Repeat;
21use futures::stream::Zip;
22use graphql_client::GraphQLQuery;
23use pin_project_lite::pin_project;
24use strum::IntoEnumIterator;
25use tokio_util::time::DelayQueue;
26
27use super::license_enforcement::LicenseLimits;
28use super::license_enforcement::TpsLimit;
29use crate::AllowedFeature;
30use crate::router::Event;
31use crate::uplink::UplinkRequest;
32use crate::uplink::UplinkResponse;
33use crate::uplink::license_enforcement::Audience;
34use crate::uplink::license_enforcement::Claims;
35use crate::uplink::license_enforcement::License;
36use crate::uplink::license_enforcement::LicenseState;
37use crate::uplink::license_enforcement::OneOrMany;
38use crate::uplink::license_stream::license_query::FetchErrorCode;
39use crate::uplink::license_stream::license_query::LicenseQueryRouterEntitlements;
40
/// Value of the `code` field attached to the error logged when a license is dropped because none
/// of its audiences is in the accepted set.
const APOLLO_ROUTER_LICENSE_OFFLINE_UNSUPPORTED: &str = "APOLLO_ROUTER_LICENSE_OFFLINE_UNSUPPORTED";
42
43#[derive(GraphQLQuery)]
44#[graphql(
45    query_path = "src/uplink/license_query.graphql",
46    schema_path = "src/uplink/uplink.graphql",
47    request_derives = "Debug",
48    response_derives = "PartialEq, Debug, Deserialize",
49    deprecated = "warn"
50)]
51pub(crate) struct LicenseQuery {}
52
53impl From<UplinkRequest> for license_query::Variables {
54    fn from(req: UplinkRequest) -> Self {
55        license_query::Variables {
56            api_key: req.api_key,
57            graph_ref: req.graph_ref,
58            if_after_id: req.id,
59        }
60    }
61}
62
63impl From<license_query::ResponseData> for UplinkResponse<License> {
64    fn from(response: license_query::ResponseData) -> Self {
65        match response.router_entitlements {
66            LicenseQueryRouterEntitlements::RouterEntitlementsResult(result) => {
67                if let Some(license) = result.entitlement {
68                    match License::from_str(&license.jwt) {
69                        Ok(jwt) => UplinkResponse::New {
70                            response: jwt,
71                            id: result.id,
72                            // this will truncate the number of seconds to under u64::MAX, which should be
73                            // a large enough delay anyway
74                            delay: result.min_delay_seconds as u64,
75                        },
76                        Err(error) => UplinkResponse::Error {
77                            retry_later: true,
78                            code: "INVALID_LICENSE".to_string(),
79                            message: error.to_string(),
80                        },
81                    }
82                } else {
83                    UplinkResponse::New {
84                        response: License::default(),
85                        id: result.id,
86                        // this will truncate the number of seconds to under u64::MAX, which should be
87                        // a large enough delay anyway
88                        delay: result.min_delay_seconds as u64,
89                    }
90                }
91            }
92            LicenseQueryRouterEntitlements::Unchanged(response) => UplinkResponse::Unchanged {
93                id: Some(response.id),
94                delay: Some(response.min_delay_seconds as u64),
95            },
96            LicenseQueryRouterEntitlements::FetchError(error) => UplinkResponse::Error {
97                retry_later: error.code == FetchErrorCode::RETRY_LATER,
98                code: match error.code {
99                    FetchErrorCode::AUTHENTICATION_FAILED => "AUTHENTICATION_FAILED".to_string(),
100                    FetchErrorCode::ACCESS_DENIED => "ACCESS_DENIED".to_string(),
101                    FetchErrorCode::UNKNOWN_REF => "UNKNOWN_REF".to_string(),
102                    FetchErrorCode::RETRY_LATER => "RETRY_LATER".to_string(),
103                    FetchErrorCode::NOT_IMPLEMENTED_ON_THIS_INSTANCE => {
104                        "NOT_IMPLEMENTED_ON_THIS_INSTANCE".to_string()
105                    }
106                    FetchErrorCode::Other(other) => other,
107                },
108                message: error.message,
109            },
110        }
111    }
112}
113
114pin_project! {
115    /// This stream wrapper will cause check the current license at the point of warn_at or halt_at.
116    /// This means that the state machine can be kept clean, and not have to deal with setting it's own timers and also avoids lots of racy scenarios as license checks are guaranteed to happen after a license update even if they were in the past.
117    #[must_use = "streams do nothing unless polled"]
118    #[project = LicenseExpanderProj]
119    pub(crate) struct LicenseExpander<Upstream>
120    where
121        Upstream: Stream<Item = License>,
122    {
123        #[pin]
124        checks: DelayQueue<Event>,
125        #[pin]
126        upstream: Fuse<Upstream>,
127    }
128}
129
130impl<Upstream> Stream for LicenseExpander<Upstream>
131where
132    Upstream: Stream<Item = License>,
133{
134    type Item = Event;
135
136    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
137        let mut this = self.project();
138        let checks = this.checks.poll_expired(cx);
139        // Only check downstream if checks was not Some
140        let next = if matches!(checks, Poll::Ready(Some(_))) {
141            None
142        } else {
143            // Poll upstream. Note that it is OK for this to be called again after it has finished as the stream is fused and if it is exhausted it will return Poll::Ready(None).
144            Some(this.upstream.poll_next(cx))
145        };
146
147        match (checks, next) {
148            // Checks has an expired claim that needs checking.
149            // This is the ONLY arm where upstream.poll_next has not been called, and this is OK because we are not returning pending.
150            (Poll::Ready(Some(item)), _) => Poll::Ready(Some(item.into_inner())),
151            // Upstream has a new license with a claim
152            (_, Some(Poll::Ready(Some(license)))) if license.claims.is_some() => {
153                // If we got a new license then we need to reset the stream of events and return the new license event.
154                reset_checks_for_licenses(&mut this.checks, license)
155            }
156            // Upstream has a new license with no claim.
157            (_, Some(Poll::Ready(Some(_)))) => {
158                // We don't clear the checks if there is a license with no claim.
159                Poll::Ready(Some(Event::UpdateLicense(Arc::new(
160                    LicenseState::Unlicensed,
161                ))))
162            }
163            // If either checks or upstream returned pending then we need to return pending.
164            // It is the responsibility of upstream and checks to schedule wakeup.
165            // If we have got to this line then checks.poll_expired and upstream.poll_next *will* have been called.
166            (Poll::Pending, _) | (_, Some(Poll::Pending)) => Poll::Pending,
167            // If both stream are exhausted then return none.
168            (Poll::Ready(None), Some(Poll::Ready(None))) => Poll::Ready(None),
169            (Poll::Ready(None), None) => {
170                unreachable!("upstream will have been called as checks did not have a value")
171            }
172        }
173    }
174}
175
176/// This function takes a license and returns the appropriate event for that license.
177/// If warn at or halt at are in the future it will register appropriate checks to trigger at such times.
178fn reset_checks_for_licenses(
179    checks: &mut DelayQueue<Event>,
180    license: License,
181) -> Poll<Option<Event>> {
182    // We got a new claim, so clear the previous checks.
183    checks.clear();
184    let claims = license.claims.as_ref().expect("claims is gated, qed");
185
186    // Router limitations based on claims
187    let limits = match (claims.tps, &claims.allowed_features) {
188        (None, None) => None,
189        (Some(tps_limit), Some(features)) => Some(
190            LicenseLimits::builder()
191                .tps(
192                    TpsLimit::builder()
193                        .capacity(tps_limit.capacity)
194                        .interval(tps_limit.interval)
195                        .build(),
196                )
197                .allowed_features(HashSet::from_iter(features.clone()))
198                .build(),
199        ),
200        (Some(tps_limit), None) => Some(LicenseLimits {
201            tps: Some(TpsLimit {
202                capacity: tps_limit.capacity,
203                interval: tps_limit.interval,
204            }),
205            allowed_features: HashSet::from_iter(AllowedFeature::iter()),
206        }),
207        (None, Some(features)) => Some(
208            LicenseLimits::builder()
209                .allowed_features(HashSet::from_iter(features.clone()))
210                .build(),
211        ),
212    };
213
214    let halt_at = to_positive_instant(claims.halt_at);
215    let warn_at = to_positive_instant(claims.warn_at);
216    let now = Instant::now();
217    // Insert the new checks. If any of the boundaries are in the past then just return the immediate result
218    if halt_at > now {
219        // Only add halt if it isn't immediately going to be triggered.
220        checks.insert_at(
221            Event::UpdateLicense(Arc::new(LicenseState::LicensedHalt {
222                limits: limits.clone(),
223            })),
224            (halt_at).into(),
225        );
226    } else {
227        return Poll::Ready(Some(Event::UpdateLicense(Arc::new(
228            LicenseState::LicensedHalt {
229                limits: limits.clone(),
230            },
231        ))));
232    }
233    if warn_at > now {
234        // Only add warn if it isn't immediately going to be triggered and halt is not already set.
235        // Something that is halted is by definition also warn.
236        checks.insert_at(
237            Event::UpdateLicense(Arc::new(LicenseState::LicensedWarn {
238                limits: limits.clone(),
239            })),
240            (warn_at).into(),
241        );
242    } else {
243        return Poll::Ready(Some(Event::UpdateLicense(Arc::new(
244            LicenseState::LicensedWarn {
245                limits: limits.clone(),
246            },
247        ))));
248    }
249
250    Poll::Ready(Some(Event::UpdateLicense(Arc::new(
251        LicenseState::Licensed {
252            limits: limits.clone(),
253        },
254    ))))
255}
256
/// Generates an approximate [`Instant`] from a [`SystemTime`].
///
/// We have externally generated unix timestamps that need to be scheduled, but anything related
/// to scheduling must be an `Instant`. There is no real conversion between the two clocks, so
/// the result is only approximate: it is "now" plus the distance between `system_time` and the
/// current system time.
///
/// Subtracting from instants is not supported on all platforms, so if `system_time` is in the
/// past the current instant is returned; we don't care how long ago it was, only that it has
/// already happened.
///
/// If `system_time` is so far in the future that the result cannot be represented as an
/// `Instant`, the result is clamped to roughly 30 years from now instead of panicking.
fn to_positive_instant(system_time: SystemTime) -> Instant {
    // Fallback offset used when the exact instant is not representable on this platform.
    const FAR_FUTURE: std::time::Duration =
        std::time::Duration::from_secs(60 * 60 * 24 * 365 * 30);

    let now_instant = Instant::now();
    let now_system_time = SystemTime::now();

    // system_time is likely to be a time in the future, but may be in the past.
    match system_time.duration_since(now_system_time) {
        // system_time was in the future. `Instant + Duration` panics on overflow, so use the
        // checked variant and clamp instead.
        Ok(duration) => now_instant
            .checked_add(duration)
            .unwrap_or_else(|| now_instant + FAR_FUTURE),

        // system_time was in the past.
        Err(_) => now_instant,
    }
}
274
275type ValidateAudience<T> = FilterMap<
276    Zip<T, Repeat<Arc<HashSet<Audience>>>>,
277    Ready<Option<License>>,
278    fn((License, Arc<HashSet<Audience>>)) -> Ready<Option<License>>,
279>;
280
281pub(crate) trait LicenseStreamExt: Stream<Item = License> {
282    fn expand_licenses(self) -> LicenseExpander<Self>
283    where
284        Self: Sized,
285    {
286        LicenseExpander {
287            checks: Default::default(),
288            upstream: self.fuse(),
289        }
290    }
291
292    fn validate_audience(self, audiences: impl Into<HashSet<Audience>>) -> ValidateAudience<Self>
293    where
294        Self: Sized,
295    {
296        // Zip is used to inject the data into the stream, and then filter_map can be used to actually deal with the data.
297        // There's no way to do this with a closure without hitting compiler issues.
298        // In the past we have implemented our own steps where we have needed to inject state, but this is the recommended way to do it.
299        let audiences: Arc<HashSet<Audience>> = Arc::new(audiences.into());
300        self.zip(futures::stream::repeat(audiences))
301            .filter_map(|(license, audiences)| {
302                let matches = match &license {
303                    License {
304                        claims:
305                            Some(Claims {
306                                aud: OneOrMany::Many(aud),
307                                ..
308                            }),
309                    } => aud.iter().any(|aud| audiences.contains(aud)),
310                    License {
311                        claims:
312                            Some(Claims {
313                                aud: OneOrMany::One(aud),
314                                ..
315                            }),
316                    } => audiences.contains(aud),
317                    // A license with no claims is always valid. We will check later if any commercial features are in use.
318                    License { claims: None } => true,
319                };
320
321                if !matches {
322                    tracing::error!(
323                        code = APOLLO_ROUTER_LICENSE_OFFLINE_UNSUPPORTED,
324                        "the license file was valid, but was not enabled offline use",
325                    );
326                }
327                futures::future::ready(if matches { Some(license) } else { None })
328            })
329    }
330}
331
332impl<T: Stream<Item = License>> LicenseStreamExt for T {}
333
#[cfg(test)]
mod test {
    use std::future::ready;
    use std::time::Duration;
    use std::time::Instant;
    use std::time::SystemTime;

    use futures::StreamExt;
    use futures_test::stream::StreamTestExt;
    use tracing::instrument::WithSubscriber;

    use crate::assert_snapshot_subscriber;
    use crate::router::Event;
    use crate::uplink::UplinkConfig;
    use crate::uplink::license_enforcement::Audience;
    use crate::uplink::license_enforcement::Claims;
    use crate::uplink::license_enforcement::License;
    use crate::uplink::license_enforcement::LicenseState;
    use crate::uplink::license_enforcement::OneOrMany;
    use crate::uplink::license_stream::LicenseQuery;
    use crate::uplink::license_stream::LicenseStreamExt;
    use crate::uplink::license_stream::to_positive_instant;
    use crate::uplink::stream_from_uplink;

    /// Fetches a license from uplink. Only runs when both `TEST_APOLLO_KEY` and
    /// `TEST_APOLLO_GRAPH_REF` are set; otherwise the test passes without doing anything.
    #[tokio::test]
    async fn integration_test() {
        if let (Ok(apollo_key), Ok(apollo_graph_ref)) = (
            std::env::var("TEST_APOLLO_KEY"),
            std::env::var("TEST_APOLLO_GRAPH_REF"),
        ) {
            let results = stream_from_uplink::<LicenseQuery, License>(UplinkConfig {
                apollo_key,
                apollo_graph_ref,
                endpoints: None,
                poll_interval: Duration::from_secs(1),
                timeout: Duration::from_secs(5),
            })
            .take(1)
            .collect::<Vec<_>>()
            .await;

            assert!(
                results
                    .first()
                    .expect("expected one result")
                    .as_ref()
                    .expect("license should be OK")
                    .claims
                    .is_some()
            )
        }
    }

    /// Future system times map to an instant the same distance ahead; past system times map to
    /// the current instant.
    #[test]
    fn test_to_instant() {
        let now_system_time = SystemTime::now();
        let now_instant = Instant::now();
        let future_system_time = now_system_time + Duration::from_secs(1024);
        let future_instant = to_positive_instant(future_system_time);
        assert!(future_instant < now_instant + Duration::from_secs(1025));
        assert!(future_instant > now_instant + Duration::from_secs(1023));

        // A system time in the past returns an instant no earlier than the original now_instant
        // and no later than a new instant. `Instant` is only guaranteed to be monotonically
        // non-decreasing, so consecutive readings may be equal and the comparisons must not be
        // strict.
        let past_system_time = now_system_time - Duration::from_secs(1024);
        let past_instant = to_positive_instant(past_system_time);
        assert!(past_instant >= now_instant);
        assert!(past_instant <= Instant::now());
    }

    /// A license with both boundaries in the future yields the license, then warn, then halt.
    #[tokio::test]
    async fn license_expander() {
        let events_stream = futures::stream::iter(vec![license_with_claim(15, 30)])
            .expand_licenses()
            .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(
            events,
            &[
                SimpleEvent::UpdateLicense,
                SimpleEvent::WarnLicense,
                SimpleEvent::HaltLicense
            ]
        );
    }

    /// A license whose warn boundary has already passed starts directly in the warn state.
    #[tokio::test]
    async fn license_expander_warn_now() {
        let events_stream = futures::stream::iter(vec![license_with_claim(0, 15)])
            .interleave_pending()
            .expand_licenses()
            .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(
            events,
            &[SimpleEvent::WarnLicense, SimpleEvent::HaltLicense]
        );
    }

    /// A license whose halt boundary has already passed only yields the halt event.
    #[tokio::test]
    async fn license_expander_halt_now() {
        let events_stream = futures::stream::iter(vec![license_with_claim(0, 0)])
            .interleave_pending()
            .expand_licenses()
            .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(events, &[SimpleEvent::HaltLicense]);
    }

    /// A license without claims yields a single update and schedules nothing.
    #[tokio::test]
    async fn license_expander_no_claim() {
        let events_stream = futures::stream::iter(vec![license_with_no_claim()])
            .interleave_pending()
            .expand_licenses()
            .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(events, &[SimpleEvent::UpdateLicense]);
    }

    /// A license without claims that follows one with claims leaves the scheduled checks alone.
    #[tokio::test]
    async fn license_expander_claim_no_claim() {
        // Licenses with no claim do not clear checks as they are ignored if we move from entitled to unentitled, this is handled at the state machine level.
        let events_stream =
            futures::stream::iter(vec![license_with_claim(10, 10), license_with_no_claim()])
                .interleave_pending()
                .expand_licenses()
                .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(
            events,
            &[
                SimpleEvent::UpdateLicense,
                SimpleEvent::UpdateLicense,
                SimpleEvent::WarnLicense,
                SimpleEvent::HaltLicense
            ]
        );
    }

    /// A license with claims that follows one without claims schedules its checks as usual.
    #[tokio::test]
    async fn license_expander_no_claim_claim() {
        let events_stream =
            futures::stream::iter(vec![license_with_no_claim(), license_with_claim(15, 30)])
                .interleave_pending()
                .expand_licenses()
                .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(
            events,
            &[
                SimpleEvent::UpdateLicense,
                SimpleEvent::UpdateLicense,
                SimpleEvent::WarnLicense,
                SimpleEvent::HaltLicense
            ]
        );
    }

    /// A new license with claims replaces the checks that are still pending for the old one.
    #[tokio::test(flavor = "multi_thread")]
    async fn license_expander_claim_pause_claim() {
        let (tx, rx) = tokio::sync::mpsc::channel(10);
        let rx_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
        let events_stream = rx_stream.expand_licenses().map(SimpleEvent::from);

        tokio::task::spawn(async move {
            // This simulates a new claim coming in between the warning and the halt of the
            // previous one.
            let _ = tx.send(license_with_claim(100, 300)).await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            let _ = tx.send(license_with_claim(100, 300)).await;
        });
        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(
            events,
            &[
                SimpleEvent::UpdateLicense,
                SimpleEvent::WarnLicense,
                SimpleEvent::UpdateLicense,
                SimpleEvent::WarnLicense,
                SimpleEvent::HaltLicense
            ]
        );
    }

    /// Builds a self-hosted license whose warn and halt boundaries lie `warn_delta` and
    /// `halt_delta` milliseconds from now.
    fn license_with_claim(warn_delta: u64, halt_delta: u64) -> License {
        let now = SystemTime::now();
        License {
            claims: Some(Claims {
                iss: "".to_string(),
                sub: "".to_string(),
                aud: OneOrMany::One(Audience::SelfHosted),
                warn_at: now + Duration::from_millis(warn_delta),
                halt_at: now + Duration::from_millis(halt_delta),
                tps: Default::default(),
                allowed_features: Default::default(),
            }),
        }
    }

    /// Builds a license that carries no claims.
    fn license_with_no_claim() -> License {
        License { claims: None }
    }

    /// Payload-free mirror of [`Event`] so that event sequences can be compared in assertions.
    #[derive(Eq, PartialEq, Debug)]
    enum SimpleEvent {
        UpdateConfiguration,
        NoMoreConfiguration,
        UpdateSchema,
        NoMoreSchema,
        UpdateLicense,
        HaltLicense,
        WarnLicense,
        NoMoreLicense,
        ForcedHotReload,
        Shutdown,
    }

    impl From<Event> for SimpleEvent {
        fn from(value: Event) -> Self {
            match value {
                Event::UpdateConfiguration(_) => SimpleEvent::UpdateConfiguration,
                Event::NoMoreConfiguration => SimpleEvent::NoMoreConfiguration,
                Event::UpdateSchema(_) => SimpleEvent::UpdateSchema,
                Event::NoMoreSchema => SimpleEvent::NoMoreSchema,
                // Halt and warn states get their own variants; every other license state is
                // reported as a plain update.
                Event::UpdateLicense(license) => match *license {
                    LicenseState::LicensedHalt { limits: _ } => SimpleEvent::HaltLicense,
                    LicenseState::LicensedWarn { limits: _ } => SimpleEvent::WarnLicense,
                    _ => SimpleEvent::UpdateLicense,
                },
                Event::NoMoreLicense => SimpleEvent::NoMoreLicense,
                Event::Reload | Event::RhaiReload => SimpleEvent::ForcedHotReload,
                Event::Shutdown => SimpleEvent::Shutdown,
            }
        }
    }

    /// A single audience that is in the accepted set passes through.
    #[tokio::test]
    async fn test_validate_audience_single() {
        assert_eq!(
            futures::stream::once(ready(License {
                claims: Some(Claims {
                    iss: "".to_string(),
                    sub: "".to_string(),
                    aud: OneOrMany::One(Audience::Offline),
                    warn_at: SystemTime::now(),
                    halt_at: SystemTime::now(),
                    tps: Default::default(),
                    allowed_features: Default::default(),
                }),
            }))
            .validate_audience([Audience::Offline, Audience::Cloud])
            .count()
            .with_subscriber(assert_snapshot_subscriber!())
            .await,
            1
        );
    }

    /// A single audience that is not in the accepted set is filtered out.
    #[tokio::test]
    async fn test_validate_audience_single_filtered() {
        assert_eq!(
            futures::stream::once(ready(License {
                claims: Some(Claims {
                    iss: "".to_string(),
                    sub: "".to_string(),
                    aud: OneOrMany::One(Audience::SelfHosted),
                    warn_at: SystemTime::now(),
                    halt_at: SystemTime::now(),
                    tps: Default::default(),
                    allowed_features: Default::default(),
                }),
            }))
            .validate_audience([Audience::Offline, Audience::Cloud])
            .count()
            .with_subscriber(assert_snapshot_subscriber!())
            .await,
            0
        );
    }

    /// Several audiences pass through when at least one of them is accepted.
    #[tokio::test]
    async fn test_validate_audience_multiple() {
        assert_eq!(
            futures::stream::once(ready(License {
                claims: Some(Claims {
                    iss: "".to_string(),
                    sub: "".to_string(),
                    aud: OneOrMany::Many(vec![Audience::SelfHosted, Audience::Offline]),
                    warn_at: SystemTime::now(),
                    halt_at: SystemTime::now(),
                    tps: Default::default(),
                    allowed_features: Default::default(),
                }),
            }))
            .validate_audience([Audience::Offline, Audience::Cloud])
            .count()
            .with_subscriber(assert_snapshot_subscriber!())
            .await,
            1
        );
    }

    /// Several audiences are filtered out when none of them is accepted.
    #[tokio::test]
    async fn test_validate_audience_multiple_filtered() {
        assert_eq!(
            futures::stream::once(ready(License {
                claims: Some(Claims {
                    iss: "".to_string(),
                    sub: "".to_string(),
                    aud: OneOrMany::Many(vec![Audience::SelfHosted, Audience::SelfHosted]),
                    warn_at: SystemTime::now(),
                    halt_at: SystemTime::now(),
                    tps: Default::default(),
                    allowed_features: Default::default(),
                }),
            }))
            .validate_audience([Audience::Offline, Audience::Cloud])
            .count()
            .with_subscriber(assert_snapshot_subscriber!())
            .await,
            0
        );
    }

    /// A license without claims always passes audience validation.
    #[tokio::test]
    async fn test_validate_no_claim() {
        assert_eq!(
            futures::stream::once(ready(License::default()))
                .validate_audience([Audience::Offline, Audience::Cloud])
                .count()
                .with_subscriber(assert_snapshot_subscriber!())
                .await,
            1
        );
    }
}