1#![allow(clippy::derive_partial_eq_without_eq)]
5
6use std::collections::HashSet;
7use std::pin::Pin;
8use std::str::FromStr;
9use std::sync::Arc;
10use std::task::Context;
11use std::task::Poll;
12use std::time::Instant;
13use std::time::SystemTime;
14
15use futures::Stream;
16use futures::StreamExt;
17use futures::future::Ready;
18use futures::stream::FilterMap;
19use futures::stream::Fuse;
20use futures::stream::Repeat;
21use futures::stream::Zip;
22use graphql_client::GraphQLQuery;
23use pin_project_lite::pin_project;
24use strum::IntoEnumIterator;
25use tokio_util::time::DelayQueue;
26
27use super::license_enforcement::LicenseLimits;
28use super::license_enforcement::TpsLimit;
29use crate::AllowedFeature;
30use crate::router::Event;
31use crate::uplink::UplinkRequest;
32use crate::uplink::UplinkResponse;
33use crate::uplink::license_enforcement::Audience;
34use crate::uplink::license_enforcement::Claims;
35use crate::uplink::license_enforcement::License;
36use crate::uplink::license_enforcement::LicenseState;
37use crate::uplink::license_enforcement::OneOrMany;
38use crate::uplink::license_stream::license_query::FetchErrorCode;
39use crate::uplink::license_stream::license_query::LicenseQueryRouterEntitlements;
40
/// Value of the `code` field on the error logged by `LicenseStreamExt::validate_audience`
/// when a license's `aud` claim matches none of the accepted audiences.
const APOLLO_ROUTER_LICENSE_OFFLINE_UNSUPPORTED: &str = "APOLLO_ROUTER_LICENSE_OFFLINE_UNSUPPORTED";
42
/// GraphQL query whose response carries `router_entitlements` (the license JWT).
///
/// `#[derive(GraphQLQuery)]` generates the `license_query` module (`Variables`,
/// `ResponseData`, `FetchErrorCode`, ...) from the query and schema files named below;
/// this file converts those types to and from `UplinkRequest` / `UplinkResponse<License>`.
// NOTE(review): the generated response types derive `PartialEq` without `Eq`, which is
// presumably why the file-level `#![allow(clippy::derive_partial_eq_without_eq)]` exists.
#[derive(GraphQLQuery)]
#[graphql(
    query_path = "src/uplink/license_query.graphql",
    schema_path = "src/uplink/uplink.graphql",
    request_derives = "Debug",
    response_derives = "PartialEq, Debug, Deserialize",
    deprecated = "warn"
)]
pub(crate) struct LicenseQuery {}
52
53impl From<UplinkRequest> for license_query::Variables {
54 fn from(req: UplinkRequest) -> Self {
55 license_query::Variables {
56 api_key: req.api_key,
57 graph_ref: req.graph_ref,
58 if_after_id: req.id,
59 }
60 }
61}
62
63impl From<license_query::ResponseData> for UplinkResponse<License> {
64 fn from(response: license_query::ResponseData) -> Self {
65 match response.router_entitlements {
66 LicenseQueryRouterEntitlements::RouterEntitlementsResult(result) => {
67 if let Some(license) = result.entitlement {
68 match License::from_str(&license.jwt) {
69 Ok(jwt) => UplinkResponse::New {
70 response: jwt,
71 id: result.id,
72 delay: result.min_delay_seconds as u64,
75 },
76 Err(error) => UplinkResponse::Error {
77 retry_later: true,
78 code: "INVALID_LICENSE".to_string(),
79 message: error.to_string(),
80 },
81 }
82 } else {
83 UplinkResponse::New {
84 response: License::default(),
85 id: result.id,
86 delay: result.min_delay_seconds as u64,
89 }
90 }
91 }
92 LicenseQueryRouterEntitlements::Unchanged(response) => UplinkResponse::Unchanged {
93 id: Some(response.id),
94 delay: Some(response.min_delay_seconds as u64),
95 },
96 LicenseQueryRouterEntitlements::FetchError(error) => UplinkResponse::Error {
97 retry_later: error.code == FetchErrorCode::RETRY_LATER,
98 code: match error.code {
99 FetchErrorCode::AUTHENTICATION_FAILED => "AUTHENTICATION_FAILED".to_string(),
100 FetchErrorCode::ACCESS_DENIED => "ACCESS_DENIED".to_string(),
101 FetchErrorCode::UNKNOWN_REF => "UNKNOWN_REF".to_string(),
102 FetchErrorCode::RETRY_LATER => "RETRY_LATER".to_string(),
103 FetchErrorCode::NOT_IMPLEMENTED_ON_THIS_INSTANCE => {
104 "NOT_IMPLEMENTED_ON_THIS_INSTANCE".to_string()
105 }
106 FetchErrorCode::Other(other) => other,
107 },
108 message: error.message,
109 },
110 }
111 }
112}
113
pin_project! {
    /// Stream adapter that turns a stream of [`License`]s into license-state [`Event`]s.
    ///
    /// Each upstream license is reported as an `Event::UpdateLicense`. For licenses with
    /// claims, the warn/halt transitions at the claimed times are also scheduled, and
    /// they are emitted when they fall due.
    #[must_use = "streams do nothing unless polled"]
    #[project = LicenseExpanderProj]
    pub(crate) struct LicenseExpander<Upstream>
    where
        Upstream: Stream<Item = License>,
    {
        // Scheduled warn/halt events, each keyed to the instant derived from the claims.
        #[pin]
        checks: DelayQueue<Event>,
        // Source of licenses; fused so it can be polled again after it has ended.
        #[pin]
        upstream: Fuse<Upstream>,
    }
}
129
130impl<Upstream> Stream for LicenseExpander<Upstream>
131where
132 Upstream: Stream<Item = License>,
133{
134 type Item = Event;
135
136 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
137 let mut this = self.project();
138 let checks = this.checks.poll_expired(cx);
139 let next = if matches!(checks, Poll::Ready(Some(_))) {
141 None
142 } else {
143 Some(this.upstream.poll_next(cx))
145 };
146
147 match (checks, next) {
148 (Poll::Ready(Some(item)), _) => Poll::Ready(Some(item.into_inner())),
151 (_, Some(Poll::Ready(Some(license)))) if license.claims.is_some() => {
153 reset_checks_for_licenses(&mut this.checks, license)
155 }
156 (_, Some(Poll::Ready(Some(_)))) => {
158 Poll::Ready(Some(Event::UpdateLicense(Arc::new(
160 LicenseState::Unlicensed,
161 ))))
162 }
163 (Poll::Pending, _) | (_, Some(Poll::Pending)) => Poll::Pending,
167 (Poll::Ready(None), Some(Poll::Ready(None))) => Poll::Ready(None),
169 (Poll::Ready(None), None) => {
170 unreachable!("upstream will have been called as checks did not have a value")
171 }
172 }
173 }
174}
175
176fn reset_checks_for_licenses(
179 checks: &mut DelayQueue<Event>,
180 license: License,
181) -> Poll<Option<Event>> {
182 checks.clear();
184 let claims = license.claims.as_ref().expect("claims is gated, qed");
185
186 let limits = match (claims.tps, &claims.allowed_features) {
188 (None, None) => None,
189 (Some(tps_limit), Some(features)) => Some(
190 LicenseLimits::builder()
191 .tps(
192 TpsLimit::builder()
193 .capacity(tps_limit.capacity)
194 .interval(tps_limit.interval)
195 .build(),
196 )
197 .allowed_features(HashSet::from_iter(features.clone()))
198 .build(),
199 ),
200 (Some(tps_limit), None) => Some(LicenseLimits {
201 tps: Some(TpsLimit {
202 capacity: tps_limit.capacity,
203 interval: tps_limit.interval,
204 }),
205 allowed_features: HashSet::from_iter(AllowedFeature::iter()),
206 }),
207 (None, Some(features)) => Some(
208 LicenseLimits::builder()
209 .allowed_features(HashSet::from_iter(features.clone()))
210 .build(),
211 ),
212 };
213
214 let halt_at = to_positive_instant(claims.halt_at);
215 let warn_at = to_positive_instant(claims.warn_at);
216 let now = Instant::now();
217 if halt_at > now {
219 checks.insert_at(
221 Event::UpdateLicense(Arc::new(LicenseState::LicensedHalt {
222 limits: limits.clone(),
223 })),
224 (halt_at).into(),
225 );
226 } else {
227 return Poll::Ready(Some(Event::UpdateLicense(Arc::new(
228 LicenseState::LicensedHalt {
229 limits: limits.clone(),
230 },
231 ))));
232 }
233 if warn_at > now {
234 checks.insert_at(
237 Event::UpdateLicense(Arc::new(LicenseState::LicensedWarn {
238 limits: limits.clone(),
239 })),
240 (warn_at).into(),
241 );
242 } else {
243 return Poll::Ready(Some(Event::UpdateLicense(Arc::new(
244 LicenseState::LicensedWarn {
245 limits: limits.clone(),
246 },
247 ))));
248 }
249
250 Poll::Ready(Some(Event::UpdateLicense(Arc::new(
251 LicenseState::Licensed {
252 limits: limits.clone(),
253 },
254 ))))
255}
256
/// Maps a wall-clock `system_time` onto the monotonic [`Instant`] timeline.
///
/// The current `Instant` is offset by the distance from the current wall-clock time to
/// `system_time`. Times at or before "now" have no future offset, so they collapse to
/// the current `Instant`.
fn to_positive_instant(system_time: SystemTime) -> Instant {
    // Read the monotonic clock first, then the wall clock, so both describe
    // (approximately) the same moment.
    let monotonic_now = Instant::now();
    let wall_now = SystemTime::now();

    system_time
        .duration_since(wall_now)
        .map_or(monotonic_now, |ahead| monotonic_now + ahead)
}
274
/// Concrete stream type returned by `LicenseStreamExt::validate_audience`.
///
/// The filter is a plain `fn` pointer, not a closure type, so it cannot capture the
/// accepted audience set. Instead, each item is paired with an `Arc` of that set via
/// `zip` with a `repeat`ed stream.
type ValidateAudience<T> = FilterMap<
    Zip<T, Repeat<Arc<HashSet<Audience>>>>,
    Ready<Option<License>>,
    fn((License, Arc<HashSet<Audience>>)) -> Ready<Option<License>>,
>;
280
281pub(crate) trait LicenseStreamExt: Stream<Item = License> {
282 fn expand_licenses(self) -> LicenseExpander<Self>
283 where
284 Self: Sized,
285 {
286 LicenseExpander {
287 checks: Default::default(),
288 upstream: self.fuse(),
289 }
290 }
291
292 fn validate_audience(self, audiences: impl Into<HashSet<Audience>>) -> ValidateAudience<Self>
293 where
294 Self: Sized,
295 {
296 let audiences: Arc<HashSet<Audience>> = Arc::new(audiences.into());
300 self.zip(futures::stream::repeat(audiences))
301 .filter_map(|(license, audiences)| {
302 let matches = match &license {
303 License {
304 claims:
305 Some(Claims {
306 aud: OneOrMany::Many(aud),
307 ..
308 }),
309 } => aud.iter().any(|aud| audiences.contains(aud)),
310 License {
311 claims:
312 Some(Claims {
313 aud: OneOrMany::One(aud),
314 ..
315 }),
316 } => audiences.contains(aud),
317 License { claims: None } => true,
319 };
320
321 if !matches {
322 tracing::error!(
323 code = APOLLO_ROUTER_LICENSE_OFFLINE_UNSUPPORTED,
324 "the license file was valid, but was not enabled offline use",
325 );
326 }
327 futures::future::ready(if matches { Some(license) } else { None })
328 })
329 }
330}
331
332impl<T: Stream<Item = License>> LicenseStreamExt for T {}
333
#[cfg(test)]
mod test {
    use std::future::ready;
    use std::time::Duration;
    use std::time::Instant;
    use std::time::SystemTime;

    use futures::StreamExt;
    use futures_test::stream::StreamTestExt;
    use tracing::instrument::WithSubscriber;

    use crate::assert_snapshot_subscriber;
    use crate::router::Event;
    use crate::uplink::UplinkConfig;
    use crate::uplink::license_enforcement::Audience;
    use crate::uplink::license_enforcement::Claims;
    use crate::uplink::license_enforcement::License;
    use crate::uplink::license_enforcement::LicenseState;
    use crate::uplink::license_enforcement::OneOrMany;
    use crate::uplink::license_stream::LicenseQuery;
    use crate::uplink::license_stream::LicenseStreamExt;
    use crate::uplink::license_stream::to_positive_instant;
    use crate::uplink::stream_from_uplink;

    // Only runs against a real uplink when TEST_APOLLO_KEY and TEST_APOLLO_GRAPH_REF are set.
    #[tokio::test]
    async fn integration_test() {
        if let (Ok(apollo_key), Ok(apollo_graph_ref)) = (
            std::env::var("TEST_APOLLO_KEY"),
            std::env::var("TEST_APOLLO_GRAPH_REF"),
        ) {
            let results = stream_from_uplink::<LicenseQuery, License>(UplinkConfig {
                apollo_key,
                apollo_graph_ref,
                endpoints: None,
                poll_interval: Duration::from_secs(1),
                timeout: Duration::from_secs(5),
            })
            .take(1)
            .collect::<Vec<_>>()
            .await;

            assert!(
                results
                    .first()
                    .expect("expected one result")
                    .as_ref()
                    .expect("license should be OK")
                    .claims
                    .is_some()
            )
        }
    }

    #[test]
    fn test_to_instant() {
        let now_system_time = SystemTime::now();
        let now_instant = Instant::now();
        let future_system_time = now_system_time + Duration::from_secs(1024);
        let future_instant = to_positive_instant(future_system_time);
        assert!(future_instant < now_instant + Duration::from_secs(1025));
        assert!(future_instant > now_instant + Duration::from_secs(1023));

        let past_system_time = now_system_time - Duration::from_secs(1024);
        // A past time collapses to the `Instant::now()` read inside `to_positive_instant`.
        // `Instant` is only guaranteed to be non-decreasing, so consecutive reads may
        // compare equal on coarse clocks; the bounds are therefore inclusive.
        let past_instant = to_positive_instant(past_system_time);
        assert!(past_instant >= now_instant);
        assert!(past_instant <= Instant::now());
    }

    #[tokio::test]
    async fn license_expander() {
        let events_stream = futures::stream::iter(vec![license_with_claim(15, 30)])
            .expand_licenses()
            .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(
            events,
            &[
                SimpleEvent::UpdateLicense,
                SimpleEvent::WarnLicense,
                SimpleEvent::HaltLicense
            ]
        );
    }

    #[tokio::test]
    async fn license_expander_warn_now() {
        let events_stream = futures::stream::iter(vec![license_with_claim(0, 15)])
            .interleave_pending()
            .expand_licenses()
            .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(
            events,
            &[SimpleEvent::WarnLicense, SimpleEvent::HaltLicense]
        );
    }

    #[tokio::test]
    async fn license_expander_halt_now() {
        let events_stream = futures::stream::iter(vec![license_with_claim(0, 0)])
            .interleave_pending()
            .expand_licenses()
            .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(events, &[SimpleEvent::HaltLicense]);
    }

    #[tokio::test]
    async fn license_expander_no_claim() {
        let events_stream = futures::stream::iter(vec![license_with_no_claim()])
            .interleave_pending()
            .expand_licenses()
            .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(events, &[SimpleEvent::UpdateLicense]);
    }

    // A claim-less license does not clear checks scheduled by an earlier licensed one,
    // so the warn/halt still fire afterwards.
    #[tokio::test]
    async fn license_expander_claim_no_claim() {
        let events_stream =
            futures::stream::iter(vec![license_with_claim(10, 10), license_with_no_claim()])
                .interleave_pending()
                .expand_licenses()
                .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(
            events,
            &[
                SimpleEvent::UpdateLicense,
                SimpleEvent::UpdateLicense,
                SimpleEvent::WarnLicense,
                SimpleEvent::HaltLicense
            ]
        );
    }

    #[tokio::test]
    async fn license_expander_no_claim_claim() {
        let events_stream =
            futures::stream::iter(vec![license_with_no_claim(), license_with_claim(15, 30)])
                .interleave_pending()
                .expand_licenses()
                .map(SimpleEvent::from);

        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(
            events,
            &[
                SimpleEvent::UpdateLicense,
                SimpleEvent::UpdateLicense,
                SimpleEvent::WarnLicense,
                SimpleEvent::HaltLicense
            ]
        );
    }

    // Timing-based: the second license arrives ~200ms in, after the first warn (+100ms)
    // but before the first halt (+300ms), so it replaces the pending halt.
    #[tokio::test(flavor = "multi_thread")]
    async fn license_expander_claim_pause_claim() {
        let (tx, rx) = tokio::sync::mpsc::channel(10);
        let rx_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
        let events_stream = rx_stream.expand_licenses().map(SimpleEvent::from);

        tokio::task::spawn(async move {
            let _ = tx.send(license_with_claim(100, 300)).await;
            tokio::time::sleep(Duration::from_millis(200)).await;
            let _ = tx.send(license_with_claim(100, 300)).await;
        });
        let events = events_stream.collect::<Vec<_>>().await;
        assert_eq!(
            events,
            &[
                SimpleEvent::UpdateLicense,
                SimpleEvent::WarnLicense,
                SimpleEvent::UpdateLicense,
                SimpleEvent::WarnLicense,
                SimpleEvent::HaltLicense
            ]
        );
    }

    /// License whose warn/halt times are `warn_delta` / `halt_delta` milliseconds from now.
    fn license_with_claim(warn_delta: u64, halt_delta: u64) -> License {
        let now = SystemTime::now();
        License {
            claims: Some(Claims {
                iss: "".to_string(),
                sub: "".to_string(),
                aud: OneOrMany::One(Audience::SelfHosted),
                warn_at: now + Duration::from_millis(warn_delta),
                halt_at: now + Duration::from_millis(halt_delta),
                tps: Default::default(),
                allowed_features: Default::default(),
            }),
        }
    }

    fn license_with_no_claim() -> License {
        License { claims: None }
    }

    /// Payload-free mirror of `Event`, so event sequences can be compared with `assert_eq!`.
    #[derive(Eq, PartialEq, Debug)]
    enum SimpleEvent {
        UpdateConfiguration,
        NoMoreConfiguration,
        UpdateSchema,
        NoMoreSchema,
        UpdateLicense,
        HaltLicense,
        WarnLicense,
        NoMoreLicense,
        ForcedHotReload,
        Shutdown,
    }

    impl From<Event> for SimpleEvent {
        fn from(value: Event) -> Self {
            match value {
                Event::UpdateConfiguration(_) => SimpleEvent::UpdateConfiguration,
                Event::NoMoreConfiguration => SimpleEvent::NoMoreConfiguration,
                Event::UpdateSchema(_) => SimpleEvent::UpdateSchema,
                Event::NoMoreSchema => SimpleEvent::NoMoreSchema,
                Event::UpdateLicense(license) => match *license {
                    LicenseState::LicensedHalt { limits: _ } => SimpleEvent::HaltLicense,
                    LicenseState::LicensedWarn { limits: _ } => SimpleEvent::WarnLicense,
                    _ => SimpleEvent::UpdateLicense,
                },
                Event::NoMoreLicense => SimpleEvent::NoMoreLicense,
                Event::Reload | Event::RhaiReload => SimpleEvent::ForcedHotReload,
                Event::Shutdown => SimpleEvent::Shutdown,
            }
        }
    }

    #[tokio::test]
    async fn test_validate_audience_single() {
        assert_eq!(
            futures::stream::once(ready(License {
                claims: Some(Claims {
                    iss: "".to_string(),
                    sub: "".to_string(),
                    aud: OneOrMany::One(Audience::Offline),
                    warn_at: SystemTime::now(),
                    halt_at: SystemTime::now(),
                    tps: Default::default(),
                    allowed_features: Default::default(),
                }),
            }))
            .validate_audience([Audience::Offline, Audience::Cloud])
            .count()
            .with_subscriber(assert_snapshot_subscriber!())
            .await,
            1
        );
    }

    #[tokio::test]
    async fn test_validate_audience_single_filtered() {
        assert_eq!(
            futures::stream::once(ready(License {
                claims: Some(Claims {
                    iss: "".to_string(),
                    sub: "".to_string(),
                    aud: OneOrMany::One(Audience::SelfHosted),
                    warn_at: SystemTime::now(),
                    halt_at: SystemTime::now(),
                    tps: Default::default(),
                    allowed_features: Default::default(),
                }),
            }))
            .validate_audience([Audience::Offline, Audience::Cloud])
            .count()
            .with_subscriber(assert_snapshot_subscriber!())
            .await,
            0
        );
    }

    #[tokio::test]
    async fn test_validate_audience_multiple() {
        assert_eq!(
            futures::stream::once(ready(License {
                claims: Some(Claims {
                    iss: "".to_string(),
                    sub: "".to_string(),
                    aud: OneOrMany::Many(vec![Audience::SelfHosted, Audience::Offline]),
                    warn_at: SystemTime::now(),
                    halt_at: SystemTime::now(),
                    tps: Default::default(),
                    allowed_features: Default::default(),
                }),
            }))
            .validate_audience([Audience::Offline, Audience::Cloud])
            .count()
            .with_subscriber(assert_snapshot_subscriber!())
            .await,
            1
        );
    }

    #[tokio::test]
    async fn test_validate_audience_multiple_filtered() {
        assert_eq!(
            futures::stream::once(ready(License {
                claims: Some(Claims {
                    iss: "".to_string(),
                    sub: "".to_string(),
                    aud: OneOrMany::Many(vec![Audience::SelfHosted, Audience::SelfHosted]),
                    warn_at: SystemTime::now(),
                    halt_at: SystemTime::now(),
                    tps: Default::default(),
                    allowed_features: Default::default(),
                }),
            }))
            .validate_audience([Audience::Offline, Audience::Cloud])
            .count()
            .with_subscriber(assert_snapshot_subscriber!())
            .await,
            0
        );
    }

    #[tokio::test]
    async fn test_validate_no_claim() {
        assert_eq!(
            futures::stream::once(ready(License::default()))
                .validate_audience([Audience::Offline, Audience::Cloud])
                .count()
                .with_subscriber(assert_snapshot_subscriber!())
                .await,
            1
        );
    }
}