Skip to main content

apollo_router/uplink/
license_enforcement.rs

1// tonic does not derive `Eq` for the gRPC message types, which causes a warning from Clippy. The
2// current suggestion is to explicitly allow the lint in the module that imports the protos.
3// Read more: https://github.com/hyperium/tonic/issues/1056
4#![allow(clippy::derive_partial_eq_without_eq)]
5
6use std::collections::HashMap;
7use std::collections::HashSet;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::str::FromStr;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14
15use apollo_compiler::schema::ExtendedType;
16use buildstructor::Builder;
17use displaydoc::Display;
18use itertools::Itertools;
19use jsonwebtoken::DecodingKey;
20use jsonwebtoken::Validation;
21use jsonwebtoken::decode;
22use jsonwebtoken::jwk::JwkSet;
23use once_cell::sync::OnceCell;
24use regex::Regex;
25use serde::Deserialize;
26use serde::Deserializer;
27use serde::Serialize;
28use serde::de::Visitor;
29use serde_json::Value;
30use strum::EnumIter;
31use strum::IntoEnumIterator;
32use thiserror::Error;
33
34use super::parsed_link_spec::ParsedLinkSpec;
35use crate::Configuration;
36use crate::plugins::authentication::jwks::convert_key_algorithm;
37use crate::spec::LINK_DIRECTIVE_NAME;
38use crate::spec::Schema;
39
/// Apollo short link included in license-expiry messages.
pub(crate) const LICENSE_EXPIRED_URL: &'static str = "https://go.apollo.dev/o/elp";

/// One-line license-expiry notice; ends with the same link as [`LICENSE_EXPIRED_URL`].
pub(crate) const LICENSE_EXPIRED_SHORT_MESSAGE: &'static str =
    "Apollo license expired https://go.apollo.dev/o/elp";

/// The literal string `APOLLO_ROUTER_LICENSE_EXPIRED`; how it is consumed is outside this file.
pub(crate) const APOLLO_ROUTER_LICENSE_EXPIRED: &'static str = "APOLLO_ROUTER_LICENSE_EXPIRED";
45
46static JWKS: OnceCell<JwkSet> = OnceCell::new();
47
48#[derive(Error, Display, Debug)]
49pub enum Error {
50    /// invalid license: {0}
51    InvalidLicense(jsonwebtoken::errors::Error),
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
55#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
56pub(crate) enum Audience {
57    SelfHosted,
58    Cloud,
59    Offline,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
63#[serde(untagged)]
64pub(crate) enum OneOrMany<T> {
65    One(T),
66    Many(Vec<T>),
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
70pub(crate) struct Claims {
71    pub(crate) iss: String,
72    pub(crate) sub: String,
73    pub(crate) aud: OneOrMany<Audience>,
74    #[serde(deserialize_with = "deserialize_epoch_seconds", rename = "warnAt")]
75    /// When to warn the user about an expiring license that must be renewed to avoid halting the
76    /// router
77    pub(crate) warn_at: SystemTime,
78    #[serde(deserialize_with = "deserialize_epoch_seconds", rename = "haltAt")]
79    /// When to halt the router because of an expired license
80    pub(crate) halt_at: SystemTime,
81    /// TPS limits. These may not exist in a License; if not, no limits apply
82    #[serde(rename = "throughputLimit")]
83    pub(crate) tps: Option<TpsLimit>,
84    /// Set of allowed features. These may not exist in a License; if not, all features are enabled
85    /// NB: This is temporary behavior and will be updated once all licenses contain an allowed_features claim.
86    #[serde(rename = "allowedFeatures")]
87    pub(crate) allowed_features: Option<Vec<AllowedFeature>>,
88}
89
90fn deserialize_epoch_seconds<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
91where
92    D: Deserializer<'de>,
93{
94    let seconds = i32::deserialize(deserializer)?;
95    Ok(UNIX_EPOCH + Duration::from_secs(seconds as u64))
96}
97
98fn deserialize_ms_into_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
99where
100    D: Deserializer<'de>,
101{
102    let seconds = i32::deserialize(deserializer)?;
103    Ok(Duration::from_millis(seconds as u64))
104}
105
106#[derive(Debug)]
107pub(crate) struct LicenseEnforcementReport {
108    restricted_config_in_use: Vec<ConfigurationRestriction>,
109    restricted_schema_in_use: Vec<SchemaViolation>,
110}
111
112impl LicenseEnforcementReport {
113    pub(crate) fn uses_restricted_features(&self) -> bool {
114        !self.restricted_config_in_use.is_empty() || !self.restricted_schema_in_use.is_empty()
115    }
116
117    pub(crate) fn build(
118        configuration: &Configuration,
119        schema: &Schema,
120        license: &LicenseState,
121    ) -> LicenseEnforcementReport {
122        LicenseEnforcementReport {
123            restricted_config_in_use: Self::validate_configuration(
124                configuration,
125                &Self::configuration_restrictions(license),
126            ),
127            restricted_schema_in_use: Self::validate_schema(
128                schema,
129                &Self::schema_restrictions(license),
130            ),
131        }
132    }
133
134    pub(crate) fn restricted_features_in_use(&self) -> Vec<String> {
135        let mut restricted_features_in_use = Vec::new();
136        for restricted_config_in_use in self.restricted_config_in_use.clone() {
137            restricted_features_in_use.push(restricted_config_in_use.name.clone());
138        }
139        for restricted_schema_in_use in self.restricted_schema_in_use.clone() {
140            match restricted_schema_in_use {
141                SchemaViolation::Spec { name, .. } => {
142                    restricted_features_in_use.push(name.clone());
143                }
144                SchemaViolation::DirectiveArgument { name, .. } => {
145                    restricted_features_in_use.push(name.clone());
146                }
147            }
148        }
149        restricted_features_in_use
150    }
151
152    fn validate_configuration(
153        configuration: &Configuration,
154        configuration_restrictions: &Vec<ConfigurationRestriction>,
155    ) -> Vec<ConfigurationRestriction> {
156        let mut selector = jsonpath_lib::selector(
157            configuration
158                .validated_yaml
159                .as_ref()
160                .unwrap_or(&Value::Null),
161        );
162        let mut configuration_violations = Vec::new();
163        for restriction in configuration_restrictions {
164            if let Some(value) = selector(&restriction.path)
165                .expect("path on restriction was not valid")
166                .first()
167            {
168                if let Some(restriction_value) = &restriction.value {
169                    if *value == restriction_value {
170                        configuration_violations.push(restriction.clone());
171                    }
172                } else {
173                    configuration_violations.push(restriction.clone());
174                }
175            }
176        }
177        configuration_violations
178    }
179
180    fn validate_schema(
181        schema: &Schema,
182        schema_restrictions: &Vec<SchemaRestriction>,
183    ) -> Vec<SchemaViolation> {
184        let link_specs = schema
185            .supergraph_schema()
186            .schema_definition
187            .directives
188            .get_all(LINK_DIRECTIVE_NAME)
189            .filter_map(|link| {
190                ParsedLinkSpec::from_link_directive(link).map(|maybe_spec| {
191                    maybe_spec.ok().map(|spec| (spec.spec_url.to_owned(), spec))
192                })?
193            })
194            .collect::<HashMap<_, _>>();
195
196        let link_specs_in_join_directive = schema
197            .supergraph_schema()
198            .schema_definition
199            .directives
200            .get_all("join__directive")
201            .filter(|join| {
202                join.specified_argument_by_name("name")
203                    .and_then(|name| name.as_str())
204                    .map(|name| name == LINK_DIRECTIVE_NAME)
205                    .unwrap_or_default()
206            })
207            .filter_map(|join| {
208                join.specified_argument_by_name("args")
209                    .and_then(|arg| arg.as_object())
210            })
211            .filter_map(|link| {
212                ParsedLinkSpec::from_join_directive_args(link).map(|maybe_spec| {
213                    maybe_spec.ok().map(|spec| (spec.spec_url.to_owned(), spec))
214                })?
215            })
216            .collect::<HashMap<_, _>>();
217
218        let mut schema_violations: Vec<SchemaViolation> = Vec::new();
219
220        for restriction in schema_restrictions {
221            match restriction {
222                SchemaRestriction::Spec {
223                    spec_url,
224                    name,
225                    version_req,
226                } => {
227                    if let Some(link_spec) = link_specs.get(spec_url)
228                        && version_req.matches(&link_spec.version)
229                    {
230                        schema_violations.push(SchemaViolation::Spec {
231                            url: link_spec.url.to_string(),
232                            name: name.to_string(),
233                        });
234                    }
235                }
236                SchemaRestriction::DirectiveArgument {
237                    spec_url,
238                    name,
239                    version_req,
240                    argument,
241                    explanation,
242                } => {
243                    if let Some(link_spec) = link_specs.get(spec_url)
244                        && version_req.matches(&link_spec.version)
245                    {
246                        let directive_name = link_spec.directive_name(name);
247                        if schema
248                            .supergraph_schema()
249                            .types
250                            .values()
251                            .flat_map(|def| match def {
252                                // To traverse additional directive locations, add match arms for the respective definition types required.
253                                ExtendedType::Object(object_type_def) => {
254                                    let directives_on_object = object_type_def
255                                        .directives
256                                        .get_all(&directive_name)
257                                        .map(|component| &component.node);
258                                    let directives_on_fields =
259                                        object_type_def.fields.values().flat_map(|field| {
260                                            field.directives.get_all(&directive_name)
261                                        });
262
263                                    directives_on_object
264                                        .chain(directives_on_fields)
265                                        .collect::<Vec<_>>()
266                                }
267                                _ => vec![],
268                            })
269                            .any(|directive| {
270                                directive.specified_argument_by_name(argument).is_some()
271                            })
272                        {
273                            schema_violations.push(SchemaViolation::DirectiveArgument {
274                                url: link_spec.url.to_string(),
275                                name: directive_name.to_string(),
276                                argument: argument.to_string(),
277                                explanation: explanation.to_string(),
278                            });
279                        }
280                    }
281                }
282                SchemaRestriction::SpecInJoinDirective {
283                    spec_url,
284                    name,
285                    version_req,
286                } => {
287                    if let Some(link_spec) = link_specs_in_join_directive.get(spec_url)
288                        && version_req.matches(&link_spec.version)
289                    {
290                        schema_violations.push(SchemaViolation::Spec {
291                            url: link_spec.url.to_string(),
292                            name: name.to_string(),
293                        });
294                    }
295                }
296            }
297        }
298
299        schema_violations
300    }
301
302    fn configuration_restrictions(license: &LicenseState) -> Vec<ConfigurationRestriction> {
303        let mut configuration_restrictions = vec![];
304
305        let allowed_features = license.get_allowed_features();
306        if !allowed_features.contains(&AllowedFeature::ApqCaching) {
307            configuration_restrictions.push(
308                ConfigurationRestriction::builder()
309                    .path("$.apq.router.cache.redis")
310                    .name("APQ caching")
311                    .build(),
312            )
313        }
314        if !allowed_features.contains(&AllowedFeature::Authentication) {
315            configuration_restrictions.push(
316                ConfigurationRestriction::builder()
317                    .path("$.authentication.router")
318                    .name("Authentication plugin")
319                    .build(),
320            );
321        }
322        if !allowed_features.contains(&AllowedFeature::Authorization) {
323            configuration_restrictions.push(
324                ConfigurationRestriction::builder()
325                    .path("$.authorization.directives")
326                    .name("Authorization directives")
327                    .build(),
328            );
329        }
330        if !allowed_features.contains(&AllowedFeature::Batching) {
331            configuration_restrictions.push(
332                ConfigurationRestriction::builder()
333                    .path("$.batching")
334                    .name("Batching support")
335                    .build(),
336            );
337        }
338        if !allowed_features.contains(&AllowedFeature::EntityCaching) {
339            configuration_restrictions.push(
340                ConfigurationRestriction::builder()
341                    .path("$.preview_entity_cache.enabled")
342                    .value(true)
343                    .name("Subgraph entity caching")
344                    .build(),
345            );
346        }
347        if !allowed_features.contains(&AllowedFeature::ResponseCaching) {
348            configuration_restrictions.push(
349                ConfigurationRestriction::builder()
350                    .path("$.response_cache.enabled")
351                    .value(true)
352                    .name("Subgraph response caching")
353                    .build(),
354            );
355        }
356        if !allowed_features.contains(&AllowedFeature::PersistedQueries) {
357            configuration_restrictions.push(
358                ConfigurationRestriction::builder()
359                    .path("$.persisted_queries")
360                    .name("Persisted queries")
361                    .build(),
362            );
363        }
364        if !allowed_features.contains(&AllowedFeature::Subscriptions) {
365            configuration_restrictions.push(
366                ConfigurationRestriction::builder()
367                    .path("$.subscription.enabled")
368                    .value(true)
369                    .name("Federated subscriptions")
370                    .build(),
371            );
372        }
373        if !allowed_features.contains(&AllowedFeature::Coprocessors) {
374            configuration_restrictions.push(
375                ConfigurationRestriction::builder()
376                    .path("$.coprocessor")
377                    .name("Coprocessor plugin")
378                    .build(),
379            )
380        }
381        if !allowed_features.contains(&AllowedFeature::DistributedQueryPlanning) {
382            configuration_restrictions.push(
383                ConfigurationRestriction::builder()
384                    .path("$.supergraph.query_planning.cache.redis")
385                    .name("Query plan caching")
386                    .build(),
387            )
388        }
389        if !allowed_features.contains(&AllowedFeature::DemandControl) {
390            configuration_restrictions.push(
391                ConfigurationRestriction::builder()
392                    .path("$.demand_control")
393                    .name("Demand control plugin")
394                    .build(),
395            );
396        }
397        if !allowed_features.contains(&AllowedFeature::Experimental) {
398            configuration_restrictions.push(
399                ConfigurationRestriction::builder()
400                    .path("$.plugins.['experimental.restricted'].enabled")
401                    .value(true)
402                    .name("Restricted")
403                    .build(),
404            );
405        }
406        // Per-operation limits are restricted but parser limits like `parser_max_recursion`
407        // where the Router only configures apollo-rs are not.
408        if !allowed_features.contains(&AllowedFeature::RequestLimits) {
409            configuration_restrictions.extend(vec![
410                ConfigurationRestriction::builder()
411                    .path("$.limits.router.max_depth")
412                    .name("Operation depth limiting")
413                    .build(),
414                ConfigurationRestriction::builder()
415                    .path("$.limits.router.max_height")
416                    .name("Operation height limiting")
417                    .build(),
418                ConfigurationRestriction::builder()
419                    .path("$.limits.router.max_root_fields")
420                    .name("Operation root fields limiting")
421                    .build(),
422                ConfigurationRestriction::builder()
423                    .path("$.limits.router.max_aliases")
424                    .name("Operation aliases limiting")
425                    .build(),
426            ]);
427        }
428
429        configuration_restrictions
430    }
431
432    fn schema_restrictions(license: &LicenseState) -> Vec<SchemaRestriction> {
433        let mut schema_restrictions = vec![];
434        let allowed_features = license.get_allowed_features();
435
436        if !allowed_features.contains(&AllowedFeature::Authentication) {
437            schema_restrictions.push(SchemaRestriction::Spec {
438                name: "authenticated".to_string(),
439                spec_url: "https://specs.apollo.dev/authenticated".to_string(),
440                version_req: semver::VersionReq {
441                    comparators: vec![semver::Comparator {
442                        op: semver::Op::Exact,
443                        major: 0,
444                        minor: 1.into(),
445                        patch: 0.into(),
446                        pre: semver::Prerelease::EMPTY,
447                    }],
448                },
449            });
450            schema_restrictions.push(SchemaRestriction::Spec {
451                name: "requiresScopes".to_string(),
452                spec_url: "https://specs.apollo.dev/requiresScopes".to_string(),
453                version_req: semver::VersionReq {
454                    comparators: vec![semver::Comparator {
455                        op: semver::Op::Exact,
456                        major: 0,
457                        minor: 1.into(),
458                        patch: 0.into(),
459                        pre: semver::Prerelease::EMPTY,
460                    }],
461                },
462            });
463        }
464
465        schema_restrictions
466    }
467}
468
469impl Display for LicenseEnforcementReport {
470    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
471        if !self.restricted_config_in_use.is_empty() {
472            let restricted_config = self
473                .restricted_config_in_use
474                .iter()
475                .map(|v| format!("* {}\n  {}", v.name, v.path.replace("$.", ".")))
476                .join("\n\n");
477            write!(f, "Configuration yaml:\n{restricted_config}")?;
478
479            if !self.restricted_schema_in_use.is_empty() {
480                writeln!(f)?;
481            }
482        }
483
484        if !self.restricted_schema_in_use.is_empty() {
485            let restricted_schema = self
486                .restricted_schema_in_use
487                .iter()
488                .map(|v| v.to_string())
489                .join("\n\n");
490
491            write!(f, "Schema features:\n{restricted_schema}")?
492        }
493
494        Ok(())
495    }
496}
497
498/// Claims extracted from the License, including ways Apollo limits the router's usage. It must be constructed from a base64 encoded JWT
499/// This API experimental and is subject to change outside of semver.
500#[derive(Debug, Clone, Default)]
501pub struct License {
502    pub(crate) claims: Option<Claims>,
503}
504
505/// Transactions Per Second limits. We talk as though this will be in seconds, but the Duration
506/// here is actually given to us in milliseconds via the License's JWT's claims
507#[derive(Builder, Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
508pub(crate) struct TpsLimit {
509    pub(crate) capacity: usize,
510
511    #[serde(
512        deserialize_with = "deserialize_ms_into_duration",
513        rename = "durationMs"
514    )]
515    pub(crate) interval: Duration,
516}
517
518/// Allowed features for a License, representing what's available to a particular pricing tier
519#[derive(Clone, Debug, Eq, PartialEq, Serialize, Hash, EnumIter)]
520#[serde(rename_all = "snake_case")]
521pub enum AllowedFeature {
522    /// Automated persisted queries
523    Apq,
524    /// APQ caching
525    ApqCaching,
526    /// Authentication plugin
527    Authentication,
528    /// Authorization directives
529    Authorization,
530    /// Batching support
531    Batching,
532    /// Coprocessor plugin
533    Coprocessors,
534    /// Demand control plugin
535    DemandControl,
536    /// Distributed query planning
537    DistributedQueryPlanning,
538    /// Subgraph entity caching
539    EntityCaching,
540    /// Subgraph response caching
541    ResponseCaching,
542    /// Experimental features in the router
543    Experimental,
544    /// Extended reference reporting
545    ExtendedReferenceReporting,
546    /// Persisted queries safelisting
547    PersistedQueries,
548    /// Request limits - depth and breadth
549    RequestLimits,
550    /// Federated subscriptions
551    Subscriptions,
552    /// Traffic shaping
553    TrafficShaping,
554    /// This represents a feature found in the license that the router does not recognize
555    Other(String),
556}
557
558impl From<&str> for AllowedFeature {
559    fn from(feature: &str) -> Self {
560        match feature {
561            "apq" => Self::Apq,
562            "apq_caching" => Self::ApqCaching,
563            "authentication" => Self::Authentication,
564            "authorization" => Self::Authorization,
565            "batching" => Self::Batching,
566            "coprocessors" => Self::Coprocessors,
567            "demand_control" => Self::DemandControl,
568            "distributed_query_planning" => Self::DistributedQueryPlanning,
569            "entity_caching" => Self::EntityCaching,
570            "response_caching" => Self::ResponseCaching,
571            "experimental" => Self::Experimental,
572            "extended_reference_reporting" => Self::ExtendedReferenceReporting,
573            "persisted_queries" => Self::PersistedQueries,
574            "request_limits" => Self::RequestLimits,
575            "subscriptions" => Self::Subscriptions,
576            "traffic_shaping" => Self::TrafficShaping,
577            other => Self::Other(other.into()),
578        }
579    }
580}
581
582impl AllowedFeature {
583    /// Creates an allowed feature from a plugin name
584    pub fn from_plugin_name(plugin_name: &str) -> Option<AllowedFeature> {
585        match plugin_name {
586            "traffic_shaping" => Some(AllowedFeature::TrafficShaping),
587            "limits" => Some(AllowedFeature::RequestLimits),
588            "subscription" => Some(AllowedFeature::Subscriptions),
589            "authorization" => Some(AllowedFeature::Authorization),
590            "authentication" => Some(AllowedFeature::Authentication),
591            "preview_entity_cache" => Some(AllowedFeature::EntityCaching),
592            "response_cache" => Some(AllowedFeature::ResponseCaching),
593            "demand_control" => Some(AllowedFeature::DemandControl),
594            "coprocessor" => Some(AllowedFeature::Coprocessors),
595            _other => None,
596        }
597    }
598}
599
600impl<'de> Deserialize<'de> for AllowedFeature {
601    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
602    where
603        D: Deserializer<'de>,
604    {
605        struct AllowedFeatureVisitor;
606
607        impl<'de> Visitor<'de> for AllowedFeatureVisitor {
608            type Value = AllowedFeature;
609
610            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
611                formatter.write_str("a string representing an allowed feature")
612            }
613
614            fn visit_str<E>(self, value: &str) -> Result<AllowedFeature, E>
615            where
616                E: serde::de::Error,
617            {
618                Ok(AllowedFeature::from(value))
619            }
620        }
621
622        deserializer.deserialize_str(AllowedFeatureVisitor)
623    }
624}
625
626/// LicenseLimits represent what can be done with a router based on the claims in the License. You
627/// might have a certain tier be limited in its capacity for transactions over a certain duration,
628/// as an example
629#[derive(Debug, Builder, Clone, Eq, PartialEq)]
630pub struct LicenseLimits {
631    /// Transaction Per Second limits. If none are found in the License's claims, there are no
632    /// limits to apply
633    pub(crate) tps: Option<TpsLimit>,
634    /// The allowed features based on the allowed features present on the License's claims
635    pub(crate) allowed_features: HashSet<AllowedFeature>,
636}
637
638impl Default for LicenseLimits {
639    fn default() -> Self {
640        Self {
641            tps: None,
642            allowed_features: HashSet::from_iter(AllowedFeature::iter()),
643        }
644    }
645}
646
647/// Licenses are converted into a stream of license states by the expander
648#[derive(Debug, Clone, Eq, PartialEq, Default, Display)]
649pub enum LicenseState {
650    /// licensed
651    Licensed { limits: Option<LicenseLimits> },
652    /// warn
653    LicensedWarn { limits: Option<LicenseLimits> },
654    /// halt
655    LicensedHalt { limits: Option<LicenseLimits> },
656
657    /// unlicensed
658    #[default]
659    Unlicensed,
660}
661
662impl LicenseState {
663    pub(crate) fn is_unlicensed(&self) -> bool {
664        matches!(self, LicenseState::Unlicensed)
665    }
666
667    pub(crate) fn is_licensed(&self) -> bool {
668        !self.is_unlicensed()
669    }
670
671    pub(crate) fn get_limits(&self) -> Option<&LicenseLimits> {
672        match self {
673            LicenseState::Licensed { limits }
674            | LicenseState::LicensedWarn { limits }
675            | LicenseState::LicensedHalt { limits } => limits.as_ref(),
676            _ => None,
677        }
678    }
679
680    pub(crate) fn get_allowed_features(&self) -> HashSet<AllowedFeature> {
681        match self {
682            LicenseState::Licensed { limits }
683            | LicenseState::LicensedWarn { limits }
684            | LicenseState::LicensedHalt { limits } => {
685                match limits {
686                    Some(limits) => limits.allowed_features.clone(),
687                    // If the license has no limits and therefore no allowed_features claim,
688                    // we're using a pricing plan that should have the feature enabled regardless.
689                    // NB: This is temporary behavior and will be updated once all licenses contain
690                    // an allowed_features claim.
691                    None => HashSet::from_iter(AllowedFeature::iter()),
692                }
693            }
694            // If we are using an expired license or an unlicesed router we return an empty feature set
695            LicenseState::Unlicensed => HashSet::new(),
696        }
697    }
698
699    pub(crate) fn get_name(&self) -> &'static str {
700        match self {
701            Self::Licensed { limits: _ } => "Licensed",
702            Self::LicensedWarn { limits: _ } => "LicensedWarn",
703            Self::LicensedHalt { limits: _ } => "LicensedHalt",
704            Self::Unlicensed => "Unlicensed",
705        }
706    }
707}
708
709impl Display for License {
710    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
711        if let Some(claims) = &self.claims {
712            write!(
713                f,
714                "{}",
715                serde_json::to_string(claims)
716                    .unwrap_or_else(|_| "claim serialization error".to_string())
717            )
718        } else {
719            write!(f, "no license")
720        }
721    }
722}
723
724impl FromStr for License {
725    type Err = Error;
726
727    fn from_str(jwt: &str) -> Result<Self, Self::Err> {
728        Self::jwks()
729            .keys
730            .iter()
731            .map(|jwk| {
732                // Set up the validation for the JWT.
733                // We don't require exp as we are only interested in haltAt and warnAt
734                let mut validation = Validation::new(
735                    convert_key_algorithm(
736                        jwk.common
737                            .key_algorithm
738                            .expect("alg is required on all keys in router.jwks.json"),
739                    )
740                    .expect("only signing algorithms are used"),
741                );
742                validation.validate_exp = false;
743                validation.set_required_spec_claims(&["iss", "sub", "aud", "warnAt", "haltAt"]);
744                validation.set_issuer(&["https://www.apollographql.com/"]);
745                validation.set_audience(&["CLOUD", "SELF_HOSTED", "OFFLINE"]);
746
747                decode::<Claims>(
748                    jwt.trim(),
749                    &DecodingKey::from_jwk(jwk).expect("router.jwks.json must be valid"),
750                    &validation,
751                )
752                .map_err(Error::InvalidLicense)
753                .map(|r| License {
754                    claims: Some(r.claims),
755                })
756            })
757            .find_or_last(|r| r.is_ok())
758            .transpose()
759            .map(|e| {
760                let e = e.unwrap_or_default();
761                tracing::debug!("decoded license {jwt}->{e}");
762                e
763            })
764    }
765}
766
767/// An individual check for the router.yaml.
768#[derive(Builder, Clone, Debug, Serialize, Deserialize)]
769pub(crate) struct ConfigurationRestriction {
770    name: String,
771    path: String,
772    value: Option<Value>,
773}
774
775// An individual check for the supergraph schema
776// #[derive(Builder, Clone, Debug, Serialize, Deserialize)]
777// pub(crate) struct SchemaRestriction {
778//     name: String,
779//     url: String,
780// }
781
/// An individual check for the supergraph schema
#[derive(Clone, Debug)]
pub(crate) enum SchemaRestriction {
    /// Restricts use of the spec identified by `spec_url`.
    ///
    /// NOTE(review): presumably `name` is the label used when reporting a violation and
    /// `version_req` limits the restriction to the matching range of spec versions — confirm
    /// against the enforcement code.
    Spec {
        spec_url: String,
        name: String,
        version_req: semver::VersionReq,
    },
    // Note: this restriction is currently unused, but its intention was to
    // traverse directives belonging to object types and their fields. It was used for
    // progressive overrides when they were gated to enterprise-only. Leaving it here for now
    // in case other directives become gated by subscription tier (there's at least one in the
    // works that's non-free)
    /// Restricts use of a single `argument` of a directive from the spec at `spec_url`;
    /// `explanation` is free-form text describing the restriction.
    #[allow(dead_code)]
    DirectiveArgument {
        spec_url: String,
        name: String,
        version_req: semver::VersionReq,
        argument: String,
        explanation: String,
    },
    // Note: this restriction is currently unused.
    // It was used for connectors when they were gated to license-only. Leaving it here for now
    // in case other directives become gated by subscription tier
    /// Restricts a spec when it is referenced from a join directive (per the variant name —
    /// confirm against the enforcement code).
    #[allow(dead_code)]
    SpecInJoinDirective {
        spec_url: String,
        name: String,
        version_req: semver::VersionReq,
    },
}
813
/// A use of a restricted feature found in the supergraph schema.
///
/// NOTE(review): presumably produced when a `SchemaRestriction` matches; the human-readable
/// rendering is implemented by the `Display` impl for this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum SchemaViolation {
    /// A restricted spec/directive is in use.
    Spec {
        /// URL of the spec, printed on an indented line under the directive.
        url: String,
        /// Directive name, printed with a leading `@`.
        name: String,
    },
    /// A restricted argument of a directive is in use.
    DirectiveArgument {
        /// URL of the spec, printed on an indented line under the directive.
        url: String,
        /// Directive name, printed as `@name.argument`.
        name: String,
        /// The restricted argument, printed after the directive name.
        argument: String,
        /// Free-form text printed after a blank line, explaining the restriction.
        explanation: String,
    },
}
827
828impl Display for SchemaViolation {
829    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
830        match self {
831            SchemaViolation::Spec { name, url } => {
832                write!(f, "* @{name}\n  {url}")
833            }
834            SchemaViolation::DirectiveArgument {
835                name,
836                url,
837                argument,
838                explanation,
839            } => {
840                write!(f, "* @{name}.{argument}\n  {url}\n\n{explanation}")
841            }
842        }
843    }
844}
845
846impl License {
847    pub(crate) fn jwks() -> &'static JwkSet {
848        JWKS.get_or_init(|| {
849            // Strip the comments from the top of the file.
850            let re = Regex::new("(?m)^//.*$").expect("regex must be valid");
851            // We have a set of test JWTs that use this dummy JWKS endpoint. See the internal docs
852            // of the router team for details on how to mint a dummy JWT for testing
853            let jwks = if let Ok(jwks_path) = std::env::var("APOLLO_TEST_INTERNAL_UPLINK_JWKS") {
854                tracing::debug!("using a dummy JWKS endpoint: {jwks_path:?}");
855                let jwks = std::fs::read_to_string(jwks_path)
856                    .expect("dummy JWKS endpoint couldn't be read into memory");
857                re.replace(&jwks, "").into_owned()
858            } else {
859                re.replace(include_str!("license.jwks.json"), "")
860                    .into_owned()
861            };
862
863            serde_json::from_str::<JwkSet>(&jwks).expect("router jwks must be valid")
864        })
865    }
866}
867
#[cfg(test)]
mod test {
    use std::collections::HashSet;
    use std::str::FromStr;
    use std::time::Duration;
    use std::time::UNIX_EPOCH;

    use insta::assert_snapshot;
    use serde_json::json;

    use crate::AllowedFeature;
    use crate::Configuration;
    use crate::spec::Schema;
    use crate::uplink::license_enforcement::Audience;
    use crate::uplink::license_enforcement::Claims;
    use crate::uplink::license_enforcement::License;
    use crate::uplink::license_enforcement::LicenseEnforcementReport;
    use crate::uplink::license_enforcement::LicenseLimits;
    use crate::uplink::license_enforcement::LicenseState;
    use crate::uplink::license_enforcement::OneOrMany;

    /// Router configuration fixture expected to trigger no configuration restrictions.
    const OSS_ROUTER_YAML: &str = include_str!("testdata/oss.router.yaml");
    /// Router configuration fixture that turns on license-restricted features.
    const RESTRICTED_ROUTER_YAML: &str = include_str!("testdata/restricted.router.yaml");
    /// Supergraph fixture with no restricted schema features.
    const OSS_GRAPHQL: &str = include_str!("testdata/oss.graphql");
    /// Supergraph fixture that uses the authorization directives.
    const AUTHORIZATION_GRAPHQL: &str = include_str!("testdata/authorization.graphql");

    /// A self-hosted license JWT (`alg: EdDSA`) whose decoded claims are `self_hosted_claims()`.
    const SELF_HOSTED_JWT: &str = "eyJhbGciOiJFZERTQSJ9.eyJpc3MiOiJodHRwczovL3d3dy5hcG9sbG9ncmFwaHFsLmNvbS8iLCJzdWIiOiJhcG9sbG8iLCJhdWQiOiJTRUxGX0hPU1RFRCIsIndhcm5BdCI6MTY3NjgwODAwMCwiaGFsdEF0IjoxNjc4MDE3NjAwfQ.tXexfjZ2SQeqSwkWQ7zD4XBoxS_Hc5x7tSNJ3ln-BCL_GH7i3U9hsIgdRQTczCAjA_jjk34w39DeSV0nTc5WBw"; // gitleaks:allow

    /// Parses both fixtures and builds the enforcement report for `license`.
    ///
    /// Parse failures panic; `#[track_caller]` attributes them to the calling test.
    #[track_caller]
    fn check(
        router_yaml: &str,
        supergraph_schema: &str,
        license: LicenseState,
    ) -> LicenseEnforcementReport {
        let config = Configuration::from_str(router_yaml).expect("router config must be valid");
        let schema =
            Schema::parse(supergraph_schema, &config).expect("supergraph schema must be valid");
        LicenseEnforcementReport::build(&config, &schema, &license)
    }

    /// A licensed state with no TPS limit whose `allowed_features` are exactly `features`.
    fn licensed_with(features: Vec<AllowedFeature>) -> LicenseState {
        LicenseState::Licensed {
            limits: Some(LicenseLimits {
                tps: None,
                allowed_features: HashSet::from_iter(features),
            }),
        }
    }

    /// The claims carried by `SELF_HOSTED_JWT`.
    fn self_hosted_claims() -> Claims {
        Claims {
            iss: "https://www.apollographql.com/".to_string(),
            sub: "apollo".to_string(),
            aud: OneOrMany::One(Audience::SelfHosted),
            warn_at: UNIX_EPOCH + Duration::from_secs(1676808000),
            halt_at: UNIX_EPOCH + Duration::from_secs(1678017600),
            tps: Default::default(),
            allowed_features: Default::default(),
        }
    }

    #[test]
    fn test_oss() {
        let report = check(OSS_ROUTER_YAML, OSS_GRAPHQL, LicenseState::default());
        assert!(
            report.restricted_config_in_use.is_empty(),
            "should not have found restricted features"
        );
    }

    #[test]
    fn test_restricted_features_via_config_unlicensed() {
        let report = check(RESTRICTED_ROUTER_YAML, OSS_GRAPHQL, LicenseState::default());
        assert!(
            !report.restricted_config_in_use.is_empty(),
            "should have found restricted features"
        );
        assert_snapshot!(report.to_string());
    }

    #[test]
    fn test_restricted_features_via_config_allowed_features_empty() {
        let report = check(RESTRICTED_ROUTER_YAML, OSS_GRAPHQL, licensed_with(vec![]));
        assert!(
            !report.restricted_config_in_use.is_empty(),
            "should have found restricted features"
        );
        assert_snapshot!(report.to_string());
    }

    #[test]
    fn test_restricted_features_via_config_with_allowed_features() {
        // The config includes subscriptions but the license's
        // allowed_features claim does not include subscriptions
        let license = licensed_with(vec![
            AllowedFeature::Authentication,
            AllowedFeature::Authorization,
            AllowedFeature::Batching,
            AllowedFeature::DemandControl,
            AllowedFeature::EntityCaching,
            AllowedFeature::PersistedQueries,
            AllowedFeature::ApqCaching,
        ]);
        let report = check(RESTRICTED_ROUTER_YAML, OSS_GRAPHQL, license);
        assert!(
            !report.restricted_config_in_use.is_empty(),
            "should have found restricted features"
        );
        assert_snapshot!(report.to_string());
    }

    #[test]
    fn test_restricted_authorization_directives_via_schema_unlicensed() {
        let report = check(OSS_ROUTER_YAML, AUTHORIZATION_GRAPHQL, LicenseState::default());
        assert!(
            !report.restricted_schema_in_use.is_empty(),
            "should have found restricted features"
        );
        assert_snapshot!(report.to_string());
    }

    #[test]
    fn test_restricted_authorization_directives_via_schema_with_restricted_allowed_features() {
        // With auth in the allowed features set, the report must contain no schema violations.
        let with_auth = check(
            OSS_ROUTER_YAML,
            AUTHORIZATION_GRAPHQL,
            licensed_with(vec![
                AllowedFeature::Authentication,
                AllowedFeature::Authorization,
            ]),
        );
        assert!(
            with_auth.restricted_schema_in_use.is_empty(),
            "should have not found restricted features"
        );

        // Without auth in the allowed features set, schema violations must be reported.
        let report = check(
            OSS_ROUTER_YAML,
            AUTHORIZATION_GRAPHQL,
            licensed_with(vec![AllowedFeature::DemandControl]),
        );
        assert!(
            !report.restricted_schema_in_use.is_empty(),
            "should have found restricted features"
        );
        assert_snapshot!(report.to_string());
    }

    // NB: this behavior will change once all licenses have an `allowed_features` claim
    #[test]
    fn test_restricted_authorization_directives_via_schema_with_default_license_limits() {
        let license = LicenseState::Licensed {
            limits: Default::default(),
        };
        let report = check(OSS_ROUTER_YAML, AUTHORIZATION_GRAPHQL, license);
        assert!(
            report.restricted_schema_in_use.is_empty(),
            "should have not found restricted features"
        );
    }

    #[test]
    #[cfg(not(windows))] // http::uri::Uri parsing appears to reject unix:// on Windows
    fn unix_socket_available_to_oss() {
        let report = check(
            OSS_ROUTER_YAML,
            include_str!("testdata/unix_socket.graphql"),
            LicenseState::default(),
        );
        assert!(
            report.restricted_schema_in_use.is_empty(),
            "shouldn't have found restricted features"
        );
    }

    #[test]
    fn schema_enforcement_allows_context_directive_for_oss() {
        let report = check(
            OSS_ROUTER_YAML,
            include_str!("testdata/set_context.graphql"),
            LicenseState::default(),
        );
        assert!(
            report.restricted_schema_in_use.is_empty(),
            "shouldn't have found restricted features"
        );
    }

    #[test]
    #[cfg(not(windows))] // http::uri::Uri parsing appears to reject unix:// on Windows
    fn test_restricted_unix_socket_via_schema_when_allowed_features_empty() {
        let report = check(
            OSS_ROUTER_YAML,
            include_str!("testdata/unix_socket.graphql"),
            licensed_with(vec![]),
        );
        assert!(
            report.restricted_schema_in_use.is_empty(),
            "shouldn't have found restricted features"
        );
    }

    #[test]
    fn test_license_parse() {
        let license = License::from_str(SELF_HOSTED_JWT).expect("must be able to decode JWT");
        assert_eq!(license.claims, Some(self_hosted_claims()));
    }

    #[test]
    fn test_license_parse_with_whitespace() {
        // Surrounding whitespace (including a trailing newline) must be tolerated.
        let padded = format!("   {SELF_HOSTED_JWT}\n ");
        let license = License::from_str(&padded).expect("must be able to decode JWT");
        assert_eq!(license.claims, Some(self_hosted_claims()));
    }

    #[test]
    fn test_license_parse_fail() {
        License::from_str("invalid").expect_err("jwt must fail parse");
    }

    #[test]
    fn claims_serde() {
        // Single and multiple audiences, and an optional allowedFeatures claim, must all parse.
        let payloads = [
            json!({
                "iss": "Issuer",
                "sub": "Subject",
                "aud": "CLOUD",
                "warnAt": 122,
                "haltAt": 123,
            }),
            json!({
                "iss": "Issuer",
                "sub": "Subject",
                "aud": ["CLOUD", "SELF_HOSTED"],
                "warnAt": 122,
                "haltAt": 123,
            }),
            json!({
                "iss": "Issuer",
                "sub": "Subject",
                "aud": "OFFLINE",
                "warnAt": 122,
                "haltAt": 123,
            }),
            json!({
                "iss": "Issuer",
                "sub": "Subject",
                "aud": "OFFLINE",
                "warnAt": 122,
                "haltAt": 123,
                "allowedFeatures": ["SUBSCRIPTIONS", "ENTITY_CACHING"]
            }),
        ];
        for payload in payloads {
            serde_json::from_value::<Claims>(payload).expect("json must deserialize");
        }
    }

    #[test]
    fn progressive_override_available_to_oss() {
        let report = check(
            OSS_ROUTER_YAML,
            include_str!("testdata/progressive_override.graphql"),
            LicenseState::default(),
        );
        // progressive override is available for oss
        assert!(
            report.restricted_schema_in_use.is_empty(),
            "shouldn't have found restricted features"
        );
    }

    #[test]
    fn set_context() {
        let report = check(
            OSS_ROUTER_YAML,
            include_str!("testdata/set_context.graphql"),
            LicenseState::default(),
        );
        assert!(
            report.restricted_schema_in_use.is_empty(),
            "shouldn't have found restricted features"
        );
    }

    #[test]
    fn progressive_override_with_renamed_join_spec() {
        let report = check(
            OSS_ROUTER_YAML,
            include_str!("testdata/progressive_override_renamed_join.graphql"),
            LicenseState::default(),
        );
        assert!(
            report.restricted_schema_in_use.is_empty(),
            "shouldn't have found restricted features"
        );
    }

    #[test]
    fn schema_enforcement_spec_version_in_range() {
        let report = check(
            OSS_ROUTER_YAML,
            include_str!("testdata/schema_enforcement_spec_version_in_range.graphql"),
            LicenseState::default(),
        );
        assert!(
            !report.restricted_schema_in_use.is_empty(),
            "should have found restricted features"
        );
        assert_snapshot!(report.to_string());
    }

    #[test]
    fn schema_enforcement_spec_version_out_of_range() {
        let report = check(
            OSS_ROUTER_YAML,
            include_str!("testdata/schema_enforcement_spec_version_out_of_range.graphql"),
            LicenseState::default(),
        );
        assert!(
            report.restricted_schema_in_use.is_empty(),
            "shouldn't have found restricted features"
        );
    }

    #[test]
    fn schema_enforcement_directive_arg_version_in_range() {
        let report = check(
            OSS_ROUTER_YAML,
            include_str!("testdata/schema_enforcement_directive_arg_version_in_range.graphql"),
            LicenseState::default(),
        );
        assert!(
            report.restricted_schema_in_use.is_empty(),
            "shouldn't have found restricted features"
        );
    }

    #[test]
    fn schema_enforcement_directive_arg_version_out_of_range() {
        let report = check(
            OSS_ROUTER_YAML,
            include_str!("testdata/schema_enforcement_directive_arg_version_out_of_range.graphql"),
            LicenseState::default(),
        );
        assert!(
            report.restricted_schema_in_use.is_empty(),
            "shouldn't have found restricted features"
        );
    }

    #[test]
    fn schema_enforcement_connectors() {
        let report = check(
            OSS_ROUTER_YAML,
            include_str!("testdata/schema_enforcement_connectors.graphql"),
            LicenseState::default(),
        );
        assert!(
            report.restricted_schema_in_use.is_empty(),
            "shouldn't have found restricted connect feature"
        );
    }
}