apollo_router/uplink/
license_enforcement.rs

// tonic does not derive `Eq` for the gRPC message types, which causes a warning from Clippy. The
// current suggestion is to explicitly allow the lint in the module that imports the protos.
// Read more: https://github.com/hyperium/tonic/issues/1056
4#![allow(clippy::derive_partial_eq_without_eq)]
5
6use std::collections::HashMap;
7use std::collections::HashSet;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::str::FromStr;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14
15use apollo_compiler::schema::ExtendedType;
16use buildstructor::Builder;
17use displaydoc::Display;
18use itertools::Itertools;
19use jsonwebtoken::DecodingKey;
20use jsonwebtoken::Validation;
21use jsonwebtoken::decode;
22use jsonwebtoken::jwk::JwkSet;
23use once_cell::sync::OnceCell;
24use regex::Regex;
25use serde::Deserialize;
26use serde::Deserializer;
27use serde::Serialize;
28use serde::de::Visitor;
29use serde_json::Value;
30use strum::IntoEnumIterator;
31use strum_macros::EnumIter;
32use thiserror::Error;
33
34use super::parsed_link_spec::ParsedLinkSpec;
35use crate::Configuration;
36use crate::plugins::authentication::jwks::convert_key_algorithm;
37use crate::spec::LINK_DIRECTIVE_NAME;
38use crate::spec::Schema;
39
/// Short link included in license-expiry messaging (the same link is embedded in
/// [`LICENSE_EXPIRED_SHORT_MESSAGE`]).
pub(crate) const LICENSE_EXPIRED_URL: &str = "https://go.apollo.dev/o/elp";
/// One-line "license expired" message, including the short link.
pub(crate) const LICENSE_EXPIRED_SHORT_MESSAGE: &str =
    "Apollo license expired https://go.apollo.dev/o/elp";

/// String code for the "license expired" condition.
// NOTE(review): presumably surfaced as an error/extension code by callers — confirm at use sites.
pub(crate) const APOLLO_ROUTER_LICENSE_EXPIRED: &str = "APOLLO_ROUTER_LICENSE_EXPIRED";

/// Lazily initialised key set used to verify license JWTs; populated on first use by
/// `License::jwks`.
static JWKS: OnceCell<JwkSet> = OnceCell::new();
47
/// Error returned when a license string cannot be decoded into a [`License`].
// NOTE: `Display` here is the `displaydoc` derive, so the doc comment on each variant IS the
// runtime `Display` message. Editing a variant's doc comment changes user-visible output.
#[derive(Error, Display, Debug)]
pub enum Error {
    /// invalid license: {0}
    InvalidLicense(jsonwebtoken::errors::Error),
}
53
/// Audience (`aud` claim) a license can be issued for.
///
/// Serialized in SCREAMING_SNAKE_CASE, i.e. `SELF_HOSTED`, `CLOUD` and `OFFLINE` — the same
/// strings accepted by the audience validation in `License::from_str`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub(crate) enum Audience {
    SelfHosted,
    Cloud,
    Offline,
}
61
/// A value that may appear either on its own or as a list.
///
/// The enum is `untagged`, so serde tries the variants in declaration order: a bare `T` first,
/// then a sequence of `T`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(untagged)]
pub(crate) enum OneOrMany<T> {
    /// A single bare value.
    One(T),
    /// A list of values.
    Many(Vec<T>),
}
68
/// Claims carried in the payload of a license JWT.
///
/// `iss`, `sub` and `aud` use their JSON names directly; the remaining fields are renamed from
/// the camelCase keys `warnAt`, `haltAt`, `throughputLimit` and `allowedFeatures`.
///
/// Note that only *deserialization* of `warn_at`/`halt_at` goes through
/// `deserialize_epoch_seconds`; serialization uses serde's default for `SystemTime`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct Claims {
    /// Issuer claim.
    pub(crate) iss: String,
    /// Subject claim.
    pub(crate) sub: String,
    /// Audience claim: either a single audience or a list of them.
    pub(crate) aud: OneOrMany<Audience>,
    #[serde(deserialize_with = "deserialize_epoch_seconds", rename = "warnAt")]
    /// When to warn the user about an expiring license that must be renewed to avoid halting the
    /// router
    pub(crate) warn_at: SystemTime,
    #[serde(deserialize_with = "deserialize_epoch_seconds", rename = "haltAt")]
    /// When to halt the router because of an expired license
    pub(crate) halt_at: SystemTime,
    /// TPS limits. These may not exist in a License; if not, no limits apply
    #[serde(rename = "throughputLimit")]
    pub(crate) tps: Option<TpsLimit>,
    /// Set of allowed features. These may not exist in a License; if not, all features are enabled
    /// NB: This is temporary behavior and will be updated once all licenses contain an allowed_features claim.
    #[serde(rename = "allowedFeatures")]
    pub(crate) allowed_features: Option<Vec<AllowedFeature>>,
}
89
90fn deserialize_epoch_seconds<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
91where
92    D: Deserializer<'de>,
93{
94    let seconds = i32::deserialize(deserializer)?;
95    Ok(UNIX_EPOCH + Duration::from_secs(seconds as u64))
96}
97
98fn deserialize_ms_into_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
99where
100    D: Deserializer<'de>,
101{
102    let seconds = i32::deserialize(deserializer)?;
103    Ok(Duration::from_millis(seconds as u64))
104}
105
/// Result of checking a configuration and supergraph schema against the restrictions implied by
/// a license state. Produced by `LicenseEnforcementReport::build`.
#[derive(Debug)]
pub(crate) struct LicenseEnforcementReport {
    /// Configuration restrictions that matched the router's validated yaml.
    restricted_config_in_use: Vec<ConfigurationRestriction>,
    /// Schema restrictions that matched the supergraph schema.
    restricted_schema_in_use: Vec<SchemaViolation>,
}
111
112impl LicenseEnforcementReport {
113    pub(crate) fn uses_restricted_features(&self) -> bool {
114        !self.restricted_config_in_use.is_empty() || !self.restricted_schema_in_use.is_empty()
115    }
116
117    pub(crate) fn build(
118        configuration: &Configuration,
119        schema: &Schema,
120        license: &LicenseState,
121    ) -> LicenseEnforcementReport {
122        LicenseEnforcementReport {
123            restricted_config_in_use: Self::validate_configuration(
124                configuration,
125                &Self::configuration_restrictions(license),
126            ),
127            restricted_schema_in_use: Self::validate_schema(
128                schema,
129                &Self::schema_restrictions(license),
130            ),
131        }
132    }
133
134    pub(crate) fn restricted_features_in_use(&self) -> Vec<String> {
135        let mut restricted_features_in_use = Vec::new();
136        for restricted_config_in_use in self.restricted_config_in_use.clone() {
137            restricted_features_in_use.push(restricted_config_in_use.name.clone());
138        }
139        for restricted_schema_in_use in self.restricted_schema_in_use.clone() {
140            match restricted_schema_in_use {
141                SchemaViolation::Spec { name, .. } => {
142                    restricted_features_in_use.push(name.clone());
143                }
144                SchemaViolation::DirectiveArgument { name, .. } => {
145                    restricted_features_in_use.push(name.clone());
146                }
147            }
148        }
149        restricted_features_in_use
150    }
151
152    fn validate_configuration(
153        configuration: &Configuration,
154        configuration_restrictions: &Vec<ConfigurationRestriction>,
155    ) -> Vec<ConfigurationRestriction> {
156        let mut selector = jsonpath_lib::selector(
157            configuration
158                .validated_yaml
159                .as_ref()
160                .unwrap_or(&Value::Null),
161        );
162        let mut configuration_violations = Vec::new();
163        for restriction in configuration_restrictions {
164            if let Some(value) = selector(&restriction.path)
165                .expect("path on restriction was not valid")
166                .first()
167            {
168                if let Some(restriction_value) = &restriction.value {
169                    if *value == restriction_value {
170                        configuration_violations.push(restriction.clone());
171                    }
172                } else {
173                    configuration_violations.push(restriction.clone());
174                }
175            }
176        }
177        configuration_violations
178    }
179
180    fn validate_schema(
181        schema: &Schema,
182        schema_restrictions: &Vec<SchemaRestriction>,
183    ) -> Vec<SchemaViolation> {
184        let link_specs = schema
185            .supergraph_schema()
186            .schema_definition
187            .directives
188            .get_all(LINK_DIRECTIVE_NAME)
189            .filter_map(|link| {
190                ParsedLinkSpec::from_link_directive(link).map(|maybe_spec| {
191                    maybe_spec.ok().map(|spec| (spec.spec_url.to_owned(), spec))
192                })?
193            })
194            .collect::<HashMap<_, _>>();
195
196        let link_specs_in_join_directive = schema
197            .supergraph_schema()
198            .schema_definition
199            .directives
200            .get_all("join__directive")
201            .filter(|join| {
202                join.specified_argument_by_name("name")
203                    .and_then(|name| name.as_str())
204                    .map(|name| name == LINK_DIRECTIVE_NAME)
205                    .unwrap_or_default()
206            })
207            .filter_map(|join| {
208                join.specified_argument_by_name("args")
209                    .and_then(|arg| arg.as_object())
210            })
211            .filter_map(|link| {
212                ParsedLinkSpec::from_join_directive_args(link).map(|maybe_spec| {
213                    maybe_spec.ok().map(|spec| (spec.spec_url.to_owned(), spec))
214                })?
215            })
216            .collect::<HashMap<_, _>>();
217
218        let mut schema_violations: Vec<SchemaViolation> = Vec::new();
219
220        for restriction in schema_restrictions {
221            match restriction {
222                SchemaRestriction::Spec {
223                    spec_url,
224                    name,
225                    version_req,
226                } => {
227                    if let Some(link_spec) = link_specs.get(spec_url)
228                        && version_req.matches(&link_spec.version)
229                    {
230                        schema_violations.push(SchemaViolation::Spec {
231                            url: link_spec.url.to_string(),
232                            name: name.to_string(),
233                        });
234                    }
235                }
236                SchemaRestriction::DirectiveArgument {
237                    spec_url,
238                    name,
239                    version_req,
240                    argument,
241                    explanation,
242                } => {
243                    if let Some(link_spec) = link_specs.get(spec_url)
244                        && version_req.matches(&link_spec.version)
245                    {
246                        let directive_name = link_spec.directive_name(name);
247                        if schema
248                            .supergraph_schema()
249                            .types
250                            .values()
251                            .flat_map(|def| match def {
252                                // To traverse additional directive locations, add match arms for the respective definition types required.
253                                ExtendedType::Object(object_type_def) => {
254                                    let directives_on_object = object_type_def
255                                        .directives
256                                        .get_all(&directive_name)
257                                        .map(|component| &component.node);
258                                    let directives_on_fields =
259                                        object_type_def.fields.values().flat_map(|field| {
260                                            field.directives.get_all(&directive_name)
261                                        });
262
263                                    directives_on_object
264                                        .chain(directives_on_fields)
265                                        .collect::<Vec<_>>()
266                                }
267                                _ => vec![],
268                            })
269                            .any(|directive| {
270                                directive.specified_argument_by_name(argument).is_some()
271                            })
272                        {
273                            schema_violations.push(SchemaViolation::DirectiveArgument {
274                                url: link_spec.url.to_string(),
275                                name: directive_name.to_string(),
276                                argument: argument.to_string(),
277                                explanation: explanation.to_string(),
278                            });
279                        }
280                    }
281                }
282                SchemaRestriction::SpecInJoinDirective {
283                    spec_url,
284                    name,
285                    version_req,
286                } => {
287                    if let Some(link_spec) = link_specs_in_join_directive.get(spec_url)
288                        && version_req.matches(&link_spec.version)
289                    {
290                        schema_violations.push(SchemaViolation::Spec {
291                            url: link_spec.url.to_string(),
292                            name: name.to_string(),
293                        });
294                    }
295                }
296            }
297        }
298
299        schema_violations
300    }
301
302    fn configuration_restrictions(license: &LicenseState) -> Vec<ConfigurationRestriction> {
303        let mut configuration_restrictions = vec![];
304
305        let allowed_features = license.get_allowed_features();
306        if !allowed_features.contains(&AllowedFeature::ApqCaching) {
307            configuration_restrictions.push(
308                ConfigurationRestriction::builder()
309                    .path("$.apq.router.cache.redis")
310                    .name("APQ caching")
311                    .build(),
312            )
313        }
314        if !allowed_features.contains(&AllowedFeature::Authentication) {
315            configuration_restrictions.push(
316                ConfigurationRestriction::builder()
317                    .path("$.authentication.router")
318                    .name("Authentication plugin")
319                    .build(),
320            );
321        }
322        if !allowed_features.contains(&AllowedFeature::Authorization) {
323            configuration_restrictions.push(
324                ConfigurationRestriction::builder()
325                    .path("$.authorization.directives")
326                    .name("Authorization directives")
327                    .build(),
328            );
329        }
330        if !allowed_features.contains(&AllowedFeature::Batching) {
331            configuration_restrictions.push(
332                ConfigurationRestriction::builder()
333                    .path("$.batching")
334                    .name("Batching support")
335                    .build(),
336            );
337        }
338        if !allowed_features.contains(&AllowedFeature::EntityCaching) {
339            configuration_restrictions.push(
340                ConfigurationRestriction::builder()
341                    .path("$.preview_entity_cache.enabled")
342                    .value(true)
343                    .name("Subgraph entity caching")
344                    .build(),
345            );
346        }
347        if !allowed_features.contains(&AllowedFeature::ResponseCaching) {
348            configuration_restrictions.push(
349                ConfigurationRestriction::builder()
350                    .path("$.preview_response_cache.enabled")
351                    .value(true)
352                    .name("Subgraph response caching")
353                    .build(),
354            );
355        }
356        if !allowed_features.contains(&AllowedFeature::PersistedQueries) {
357            configuration_restrictions.push(
358                ConfigurationRestriction::builder()
359                    .path("$.persisted_queries")
360                    .name("Persisted queries")
361                    .build(),
362            );
363        }
364        if !allowed_features.contains(&AllowedFeature::Subscriptions) {
365            configuration_restrictions.push(
366                ConfigurationRestriction::builder()
367                    .path("$.subscription.enabled")
368                    .value(true)
369                    .name("Federated subscriptions")
370                    .build(),
371            );
372        }
373        if !allowed_features.contains(&AllowedFeature::Coprocessors) {
374            configuration_restrictions.push(
375                ConfigurationRestriction::builder()
376                    .path("$.coprocessor")
377                    .name("Coprocessor plugin")
378                    .build(),
379            )
380        }
381        if !allowed_features.contains(&AllowedFeature::DistributedQueryPlanning) {
382            configuration_restrictions.push(
383                ConfigurationRestriction::builder()
384                    .path("$.supergraph.query_planning.cache.redis")
385                    .name("Query plan caching")
386                    .build(),
387            )
388        }
389        if !allowed_features.contains(&AllowedFeature::DemandControl) {
390            configuration_restrictions.push(
391                ConfigurationRestriction::builder()
392                    .path("$.demand_control")
393                    .name("Demand control plugin")
394                    .build(),
395            );
396        }
397        if !allowed_features.contains(&AllowedFeature::Experimental) {
398            configuration_restrictions.push(
399                ConfigurationRestriction::builder()
400                    .path("$.plugins.['experimental.restricted'].enabled")
401                    .value(true)
402                    .name("Restricted")
403                    .build(),
404            );
405        }
406        // Per-operation limits are restricted but parser limits like `parser_max_recursion`
407        // where the Router only configures apollo-rs are not.
408        if !allowed_features.contains(&AllowedFeature::RequestLimits) {
409            configuration_restrictions.extend(vec![
410                ConfigurationRestriction::builder()
411                    .path("$.limits.max_depth")
412                    .name("Operation depth limiting")
413                    .build(),
414                ConfigurationRestriction::builder()
415                    .path("$.limits.max_height")
416                    .name("Operation height limiting")
417                    .build(),
418                ConfigurationRestriction::builder()
419                    .path("$.limits.max_root_fields")
420                    .name("Operation root fields limiting")
421                    .build(),
422                ConfigurationRestriction::builder()
423                    .path("$.limits.max_aliases")
424                    .name("Operation aliases limiting")
425                    .build(),
426            ]);
427        }
428
429        configuration_restrictions
430    }
431
432    fn schema_restrictions(license: &LicenseState) -> Vec<SchemaRestriction> {
433        let mut schema_restrictions = vec![];
434        let allowed_features = license.get_allowed_features();
435
436        if !allowed_features.contains(&AllowedFeature::Authentication) {
437            schema_restrictions.push(SchemaRestriction::Spec {
438                name: "authenticated".to_string(),
439                spec_url: "https://specs.apollo.dev/authenticated".to_string(),
440                version_req: semver::VersionReq {
441                    comparators: vec![semver::Comparator {
442                        op: semver::Op::Exact,
443                        major: 0,
444                        minor: 1.into(),
445                        patch: 0.into(),
446                        pre: semver::Prerelease::EMPTY,
447                    }],
448                },
449            });
450            schema_restrictions.push(SchemaRestriction::Spec {
451                name: "requiresScopes".to_string(),
452                spec_url: "https://specs.apollo.dev/requiresScopes".to_string(),
453                version_req: semver::VersionReq {
454                    comparators: vec![semver::Comparator {
455                        op: semver::Op::Exact,
456                        major: 0,
457                        minor: 1.into(),
458                        patch: 0.into(),
459                        pre: semver::Prerelease::EMPTY,
460                    }],
461                },
462            });
463        }
464
465        schema_restrictions
466    }
467}
468
469impl Display for LicenseEnforcementReport {
470    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
471        if !self.restricted_config_in_use.is_empty() {
472            let restricted_config = self
473                .restricted_config_in_use
474                .iter()
475                .map(|v| format!("* {}\n  {}", v.name, v.path.replace("$.", ".")))
476                .join("\n\n");
477            write!(f, "Configuration yaml:\n{restricted_config}")?;
478
479            if !self.restricted_schema_in_use.is_empty() {
480                writeln!(f)?;
481            }
482        }
483
484        if !self.restricted_schema_in_use.is_empty() {
485            let restricted_schema = self
486                .restricted_schema_in_use
487                .iter()
488                .map(|v| v.to_string())
489                .join("\n\n");
490
491            write!(f, "Schema features:\n{restricted_schema}")?
492        }
493
494        Ok(())
495    }
496}
497
/// Claims extracted from the License, including ways Apollo limits the router's usage. It must be
/// constructed from a base64 encoded JWT.
/// This API is experimental and is subject to change outside of semver.
#[derive(Debug, Clone, Default)]
pub struct License {
    /// Decoded claims of the license; `None` for the default value, which is displayed as
    /// "no license".
    pub(crate) claims: Option<Claims>,
}
504
/// Transactions Per Second limits. We talk as though this will be in seconds, but the Duration
/// here is actually given to us in milliseconds via the License's JWT's claims
#[derive(Builder, Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub(crate) struct TpsLimit {
    // NOTE(review): presumably the number of transactions permitted per `interval` — confirm
    // against the rate-limiting code that consumes this.
    pub(crate) capacity: usize,

    /// Length of the limiting window. Read from the `durationMs` claim, which is a number of
    /// milliseconds (see `deserialize_ms_into_duration`).
    #[serde(
        deserialize_with = "deserialize_ms_into_duration",
        rename = "durationMs"
    )]
    pub(crate) interval: Duration,
}
517
/// Allowed features for a License, representing what's available to a particular pricing tier
///
/// Variants serialize with snake_case names. Deserialization is implemented by hand (see the
/// `Deserialize` impl) so that unrecognized feature names become [`AllowedFeature::Other`]
/// instead of failing.
// NOTE: variant order is the iteration order of the derived `EnumIter`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Hash, EnumIter)]
#[serde(rename_all = "snake_case")]
pub enum AllowedFeature {
    /// Automated persisted queries
    Apq,
    /// APQ caching
    ApqCaching,
    /// Authentication plugin
    Authentication,
    /// Authorization directives
    Authorization,
    /// Batching support
    Batching,
    /// Coprocessor plugin
    Coprocessors,
    /// Demand control plugin
    DemandControl,
    /// Distributed query planning
    DistributedQueryPlanning,
    /// Subgraph entity caching
    EntityCaching,
    /// Subgraph response caching
    ResponseCaching,
    /// Experimental features in the router
    Experimental,
    /// Extended reference reporting
    ExtendedReferenceReporting,
    /// Persisted queries safelisting
    PersistedQueries,
    /// Request limits - depth and breadth
    RequestLimits,
    /// Federated subscriptions
    Subscriptions,
    /// Traffic shaping
    TrafficShaping,
    /// This represents a feature found in the license that the router does not recognize
    Other(String),
}
557
558impl From<&str> for AllowedFeature {
559    fn from(feature: &str) -> Self {
560        match feature {
561            "apq" => Self::Apq,
562            "apq_caching" => Self::ApqCaching,
563            "authentication" => Self::Authentication,
564            "authorization" => Self::Authorization,
565            "batching" => Self::Batching,
566            "coprocessors" => Self::Coprocessors,
567            "demand_control" => Self::DemandControl,
568            "distributed_query_planning" => Self::DistributedQueryPlanning,
569            "entity_caching" => Self::EntityCaching,
570            "response_caching" => Self::ResponseCaching,
571            "experimental" => Self::Experimental,
572            "extended_reference_reporting" => Self::ExtendedReferenceReporting,
573            "persisted_queries" => Self::PersistedQueries,
574            "request_limits" => Self::RequestLimits,
575            "subscriptions" => Self::Subscriptions,
576            "traffic_shaping" => Self::TrafficShaping,
577            other => Self::Other(other.into()),
578        }
579    }
580}
581
582impl AllowedFeature {
583    /// Creates an allowed feature from a plugin name
584    pub fn from_plugin_name(plugin_name: &str) -> Option<AllowedFeature> {
585        match plugin_name {
586            "traffic_shaping" => Some(AllowedFeature::TrafficShaping),
587            "limits" => Some(AllowedFeature::RequestLimits),
588            "subscription" => Some(AllowedFeature::Subscriptions),
589            "authorization" => Some(AllowedFeature::Authorization),
590            "authentication" => Some(AllowedFeature::Authentication),
591            "preview_entity_cache" => Some(AllowedFeature::EntityCaching),
592            "preview_response_cache" => Some(AllowedFeature::ResponseCaching),
593            "demand_control" => Some(AllowedFeature::DemandControl),
594            "coprocessor" => Some(AllowedFeature::Coprocessors),
595            _other => None,
596        }
597    }
598}
599
600impl<'de> Deserialize<'de> for AllowedFeature {
601    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
602    where
603        D: Deserializer<'de>,
604    {
605        struct AllowedFeatureVisitor;
606
607        impl<'de> Visitor<'de> for AllowedFeatureVisitor {
608            type Value = AllowedFeature;
609
610            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
611                formatter.write_str("a string representing an allowed feature")
612            }
613
614            fn visit_str<E>(self, value: &str) -> Result<AllowedFeature, E>
615            where
616                E: serde::de::Error,
617            {
618                Ok(AllowedFeature::from(value))
619            }
620        }
621
622        deserializer.deserialize_str(AllowedFeatureVisitor)
623    }
624}
625
/// LicenseLimits represent what can be done with a router based on the claims in the License. You
/// might have a certain tier be limited in its capacity for transactions over a certain duration,
/// as an example
///
/// The `Default` value applies no TPS limit and allows every feature produced by
/// `AllowedFeature::iter()`.
#[derive(Debug, Builder, Clone, Eq, PartialEq)]
pub struct LicenseLimits {
    /// Transaction Per Second limits. If none are found in the License's claims, there are no
    /// limits to apply
    pub(crate) tps: Option<TpsLimit>,
    /// The allowed features based on the allowed features present on the License's claims
    pub(crate) allowed_features: HashSet<AllowedFeature>,
}
637
638impl Default for LicenseLimits {
639    fn default() -> Self {
640        Self {
641            tps: None,
642            allowed_features: HashSet::from_iter(AllowedFeature::iter()),
643        }
644    }
645}
646
/// Licenses are converted into a stream of license states by the expander
// NOTE: `Display` here is the `displaydoc` derive, so the doc comment on each variant IS the
// runtime `Display` text ("licensed", "warn", "halt", "unlicensed"). Do not reword them.
// In the licensed variants, `limits` is `None` when the license carries no limits; see
// `get_allowed_features` for how that case is interpreted.
#[derive(Debug, Clone, Eq, PartialEq, Default, Display)]
pub enum LicenseState {
    /// licensed
    Licensed { limits: Option<LicenseLimits> },
    /// warn
    LicensedWarn { limits: Option<LicenseLimits> },
    /// halt
    LicensedHalt { limits: Option<LicenseLimits> },

    /// unlicensed
    #[default]
    Unlicensed,
}
661
662impl LicenseState {
663    pub(crate) fn get_limits(&self) -> Option<&LicenseLimits> {
664        match self {
665            LicenseState::Licensed { limits }
666            | LicenseState::LicensedWarn { limits }
667            | LicenseState::LicensedHalt { limits } => limits.as_ref(),
668            _ => None,
669        }
670    }
671
672    pub(crate) fn get_allowed_features(&self) -> HashSet<AllowedFeature> {
673        match self {
674            LicenseState::Licensed { limits }
675            | LicenseState::LicensedWarn { limits }
676            | LicenseState::LicensedHalt { limits } => {
677                match limits {
678                    Some(limits) => limits.allowed_features.clone(),
679                    // If the license has no limits and therefore no allowed_features claim,
680                    // we're using a pricing plan that should have the feature enabled regardless.
681                    // NB: This is temporary behavior and will be updated once all licenses contain
682                    // an allowed_features claim.
683                    None => HashSet::from_iter(AllowedFeature::iter()),
684                }
685            }
686            // If we are using an expired license or an unlicesed router we return an empty feature set
687            LicenseState::Unlicensed => HashSet::new(),
688        }
689    }
690
691    pub(crate) fn get_name(&self) -> &'static str {
692        match self {
693            Self::Licensed { limits: _ } => "Licensed",
694            Self::LicensedWarn { limits: _ } => "LicensedWarn",
695            Self::LicensedHalt { limits: _ } => "LicensedHalt",
696            Self::Unlicensed => "Unlicensed",
697        }
698    }
699}
700
701impl Display for License {
702    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
703        if let Some(claims) = &self.claims {
704            write!(
705                f,
706                "{}",
707                serde_json::to_string(claims)
708                    .unwrap_or_else(|_| "claim serialization error".to_string())
709            )
710        } else {
711            write!(f, "no license")
712        }
713    }
714}
715
716impl FromStr for License {
717    type Err = Error;
718
719    fn from_str(jwt: &str) -> Result<Self, Self::Err> {
720        Self::jwks()
721            .keys
722            .iter()
723            .map(|jwk| {
724                // Set up the validation for the JWT.
725                // We don't require exp as we are only interested in haltAt and warnAt
726                let mut validation = Validation::new(
727                    convert_key_algorithm(
728                        jwk.common
729                            .key_algorithm
730                            .expect("alg is required on all keys in router.jwks.json"),
731                    )
732                    .expect("only signing algorithms are used"),
733                );
734                validation.validate_exp = false;
735                validation.set_required_spec_claims(&["iss", "sub", "aud", "warnAt", "haltAt"]);
736                validation.set_issuer(&["https://www.apollographql.com/"]);
737                validation.set_audience(&["CLOUD", "SELF_HOSTED", "OFFLINE"]);
738
739                decode::<Claims>(
740                    jwt.trim(),
741                    &DecodingKey::from_jwk(jwk).expect("router.jwks.json must be valid"),
742                    &validation,
743                )
744                .map_err(Error::InvalidLicense)
745                .map(|r| License {
746                    claims: Some(r.claims),
747                })
748            })
749            .find_or_last(|r| r.is_ok())
750            .transpose()
751            .map(|e| {
752                let e = e.unwrap_or_default();
753                tracing::debug!("decoded license {jwt}->{e}");
754                e
755            })
756    }
757}
758
/// An individual check for the router.yaml.
#[derive(Builder, Clone, Debug, Serialize, Deserialize)]
pub(crate) struct ConfigurationRestriction {
    /// Human-readable name of the restricted feature, shown in the enforcement report.
    name: String,
    /// JSONPath expression (e.g. `$.batching`) evaluated against the validated yaml.
    path: String,
    /// When `Some`, the restriction is violated only if the first match of `path` equals this
    /// value; when `None`, any match of `path` is a violation.
    value: Option<Value>,
}
766
767// An individual check for the supergraph schema
768// #[derive(Builder, Clone, Debug, Serialize, Deserialize)]
769// pub(crate) struct SchemaRestriction {
770//     name: String,
771//     url: String,
772// }
773
/// An individual check for the supergraph schema
///
/// In every variant, `spec_url` is the key a parsed link spec is looked up by and
/// `version_req` must match that spec's version for the restriction to apply.
#[derive(Clone, Debug)]
pub(crate) enum SchemaRestriction {
    /// Restricts a spec imported through `@link` on the schema definition. `name` is the
    /// feature name reported in the violation.
    Spec {
        spec_url: String,
        name: String,
        version_req: semver::VersionReq,
    },
    // Note: this restriction is currently unused, but its intention was to
    // traverse directives belonging to object types and their fields. It was used for
    // progressive overrides when they were gated to enterprise-only. Leaving it here for now
    // in case other directives become gated by subscription tier (there's at least one in the
    // works that's non-free)
    //
    // `name` is the directive's name within the spec (resolved through the link to the name
    // used in the schema), `argument` is the restricted argument, and `explanation` is copied
    // into the violation.
    #[allow(dead_code)]
    DirectiveArgument {
        spec_url: String,
        name: String,
        version_req: semver::VersionReq,
        argument: String,
        explanation: String,
    },
    // Note: this restriction is currently unused.
    // It was used for connectors when they were gated to license-only. Leaving it here for now
    // in case other directives become gated by subscription tier
    //
    // Same as `Spec`, but the spec is looked up among links declared through
    // `@join__directive` instead of `@link`.
    #[allow(dead_code)]
    SpecInJoinDirective {
        spec_url: String,
        name: String,
        version_req: semver::VersionReq,
    },
}
805
/// A restricted schema feature that was found to be in use.
///
/// The `Display` implementation renders each violation as a bullet of the form `* @name`
/// (or `* @name.argument`) followed by the URL on the next line.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum SchemaViolation {
    /// Violation tied to a spec.
    Spec {
        /// URL printed underneath the violation's bullet line.
        url: String,
        /// Name printed after the `@` in the bullet line.
        name: String,
    },
    /// Violation tied to one argument of a directive.
    DirectiveArgument {
        /// URL printed underneath the violation's bullet line.
        url: String,
        /// Name printed after the `@` in the bullet line.
        name: String,
        /// Argument name, printed after `name` separated by a dot.
        argument: String,
        /// Explanatory text printed after the URL, separated by a blank line.
        explanation: String,
    },
}
819
820impl Display for SchemaViolation {
821    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
822        match self {
823            SchemaViolation::Spec { name, url } => {
824                write!(f, "* @{name}\n  {url}")
825            }
826            SchemaViolation::DirectiveArgument {
827                name,
828                url,
829                argument,
830                explanation,
831            } => {
832                write!(f, "* @{name}.{argument}\n  {url}\n\n{explanation}")
833            }
834        }
835    }
836}
837
838impl License {
839    pub(crate) fn jwks() -> &'static JwkSet {
840        JWKS.get_or_init(|| {
841            // Strip the comments from the top of the file.
842            let re = Regex::new("(?m)^//.*$").expect("regex must be valid");
843            // We have a set of test JWTs that use this dummy JWKS endpoint. See the internal docs
844            // of the router team for details on how to mint a dummy JWT for testing
845            let jwks = if let Ok(jwks_path) = std::env::var("APOLLO_TEST_INTERNAL_UPLINK_JWKS") {
846                tracing::debug!("using a dummy JWKS endpoint: {jwks_path:?}");
847                let jwks = std::fs::read_to_string(jwks_path)
848                    .expect("dummy JWKS endpoint couldn't be read into memory");
849                re.replace(&jwks, "").into_owned()
850            } else {
851                re.replace(include_str!("license.jwks.json"), "")
852                    .into_owned()
853            };
854
855            serde_json::from_str::<JwkSet>(&jwks).expect("router jwks must be valid")
856        })
857    }
858}
859
860#[cfg(test)]
861mod test {
862    use std::collections::HashSet;
863    use std::str::FromStr;
864    use std::time::Duration;
865    use std::time::UNIX_EPOCH;
866
867    use insta::assert_snapshot;
868    use serde_json::json;
869
870    use crate::AllowedFeature;
871    use crate::Configuration;
872    use crate::spec::Schema;
873    use crate::uplink::license_enforcement::Audience;
874    use crate::uplink::license_enforcement::Claims;
875    use crate::uplink::license_enforcement::License;
876    use crate::uplink::license_enforcement::LicenseEnforcementReport;
877    use crate::uplink::license_enforcement::LicenseLimits;
878    use crate::uplink::license_enforcement::LicenseState;
879    use crate::uplink::license_enforcement::OneOrMany;
880
881    #[track_caller]
882    fn check(
883        router_yaml: &str,
884        supergraph_schema: &str,
885        license: LicenseState,
886    ) -> LicenseEnforcementReport {
887        let config = Configuration::from_str(router_yaml).expect("router config must be valid");
888        let schema =
889            Schema::parse(supergraph_schema, &config).expect("supergraph schema must be valid");
890
891        LicenseEnforcementReport::build(&config, &schema, &license)
892    }
893
894    #[test]
895    fn test_oss() {
896        let report = check(
897            include_str!("testdata/oss.router.yaml"),
898            include_str!("testdata/oss.graphql"),
899            LicenseState::default(),
900        );
901
902        assert!(
903            report.restricted_config_in_use.is_empty(),
904            "should not have found restricted features"
905        );
906    }
907
908    #[test]
909    fn test_restricted_features_via_config_unlicensed() {
910        let report = check(
911            include_str!("testdata/restricted.router.yaml"),
912            include_str!("testdata/oss.graphql"),
913            LicenseState::default(),
914        );
915
916        assert!(
917            !report.restricted_config_in_use.is_empty(),
918            "should have found restricted features"
919        );
920        assert_snapshot!(report.to_string());
921    }
922
923    #[test]
924    fn test_restricted_features_via_config_allowed_features_empty() {
925        let report = check(
926            include_str!("testdata/restricted.router.yaml"),
927            include_str!("testdata/oss.graphql"),
928            LicenseState::Licensed {
929                limits: Some(LicenseLimits {
930                    tps: None,
931                    allowed_features: HashSet::from_iter(vec![]),
932                }),
933            },
934        );
935
936        assert!(
937            !report.restricted_config_in_use.is_empty(),
938            "should have found restricted features"
939        );
940        assert_snapshot!(report.to_string());
941    }
942
943    #[test]
944    fn test_restricted_features_via_config_with_allowed_features() {
945        // The config includes subscriptions but the license's
946        // allowed_features claim does not include subscriptions
947        let report = check(
948            include_str!("testdata/restricted.router.yaml"),
949            include_str!("testdata/oss.graphql"),
950            LicenseState::Licensed {
951                limits: Some(LicenseLimits {
952                    tps: None,
953                    allowed_features: HashSet::from_iter(vec![
954                        AllowedFeature::Authentication,
955                        AllowedFeature::Authorization,
956                        AllowedFeature::Batching,
957                        AllowedFeature::DemandControl,
958                        AllowedFeature::EntityCaching,
959                        AllowedFeature::PersistedQueries,
960                        AllowedFeature::ApqCaching,
961                    ]),
962                }),
963            },
964        );
965
966        assert!(
967            !report.restricted_config_in_use.is_empty(),
968            "should have found restricted features"
969        );
970        assert_snapshot!(report.to_string());
971    }
972
973    #[test]
974    fn test_restricted_authorization_directives_via_schema_unlicensed() {
975        let report = check(
976            include_str!("testdata/oss.router.yaml"),
977            include_str!("testdata/authorization.graphql"),
978            LicenseState::default(),
979        );
980
981        assert!(
982            !report.restricted_schema_in_use.is_empty(),
983            "should have found restricted features"
984        );
985        assert_snapshot!(report.to_string());
986    }
987
988    #[test]
989    fn test_restricted_authorization_directives_via_schema_with_restricted_allowed_features() {
990        // When auth is contained within the allowed features set
991        // we should not find any schema violations in the report
992        let report = check(
993            include_str!("testdata/oss.router.yaml"),
994            include_str!("testdata/authorization.graphql"),
995            LicenseState::Licensed {
996                limits: Some(LicenseLimits {
997                    tps: None,
998                    allowed_features: HashSet::from_iter(vec![
999                        AllowedFeature::Authentication,
1000                        AllowedFeature::Authorization,
1001                    ]),
1002                }),
1003            },
1004        );
1005        assert!(
1006            report.restricted_schema_in_use.is_empty(),
1007            "should have not found restricted features"
1008        );
1009
1010        // When auth is not contained within the allowed features set
1011        // we should find schema violations in the report
1012        let report = check(
1013            include_str!("testdata/oss.router.yaml"),
1014            include_str!("testdata/authorization.graphql"),
1015            LicenseState::Licensed {
1016                limits: Some(LicenseLimits {
1017                    tps: None,
1018                    allowed_features: HashSet::from_iter(vec![AllowedFeature::DemandControl]),
1019                }),
1020            },
1021        );
1022        assert!(
1023            !report.restricted_schema_in_use.is_empty(),
1024            "should have found restricted features"
1025        );
1026        assert_snapshot!(report.to_string());
1027    }
1028
1029    // NB: this behavior will change once all licenses have an `allowed_features` claim
1030    #[test]
1031    fn test_restricted_authorization_directives_via_schema_with_default_license_limits() {
1032        let report = check(
1033            include_str!("testdata/oss.router.yaml"),
1034            include_str!("testdata/authorization.graphql"),
1035            LicenseState::Licensed {
1036                limits: Default::default(),
1037            },
1038        );
1039
1040        assert!(
1041            report.restricted_schema_in_use.is_empty(),
1042            "should have not found restricted features"
1043        );
1044    }
1045
1046    #[test]
1047    #[cfg(not(windows))] // http::uri::Uri parsing appears to reject unix:// on Windows
1048    fn unix_socket_available_to_oss() {
1049        let report = check(
1050            include_str!("testdata/oss.router.yaml"),
1051            include_str!("testdata/unix_socket.graphql"),
1052            LicenseState::default(),
1053        );
1054
1055        assert!(
1056            report.restricted_schema_in_use.is_empty(),
1057            "shouldn't have found restricted features"
1058        );
1059    }
1060
1061    #[test]
1062    fn schema_enforcement_allows_context_directive_for_oss() {
1063        let report = check(
1064            include_str!("testdata/oss.router.yaml"),
1065            include_str!("testdata/set_context.graphql"),
1066            LicenseState::default(),
1067        );
1068
1069        assert!(
1070            report.restricted_schema_in_use.is_empty(),
1071            "shouldn't have found restricted features"
1072        );
1073    }
1074
1075    #[test]
1076    #[cfg(not(windows))] // http::uri::Uri parsing appears to reject unix:// on Windows
1077    fn test_restricted_unix_socket_via_schema_when_allowed_features_empty() {
1078        let report = check(
1079            include_str!("testdata/oss.router.yaml"),
1080            include_str!("testdata/unix_socket.graphql"),
1081            LicenseState::Licensed {
1082                limits: Some(LicenseLimits {
1083                    tps: None,
1084                    allowed_features: HashSet::new(),
1085                }),
1086            },
1087        );
1088
1089        assert!(
1090            report.restricted_schema_in_use.is_empty(),
1091            "shouldn't have found restricted features"
1092        );
1093    }
1094
1095    #[test]
1096    fn test_license_parse() {
1097        let license = License::from_str("eyJhbGciOiJFZERTQSJ9.eyJpc3MiOiJodHRwczovL3d3dy5hcG9sbG9ncmFwaHFsLmNvbS8iLCJzdWIiOiJhcG9sbG8iLCJhdWQiOiJTRUxGX0hPU1RFRCIsIndhcm5BdCI6MTY3NjgwODAwMCwiaGFsdEF0IjoxNjc4MDE3NjAwfQ.tXexfjZ2SQeqSwkWQ7zD4XBoxS_Hc5x7tSNJ3ln-BCL_GH7i3U9hsIgdRQTczCAjA_jjk34w39DeSV0nTc5WBw").expect("must be able to decode JWT"); // gitleaks:allow
1098
1099        assert_eq!(
1100            license.claims,
1101            Some(Claims {
1102                iss: "https://www.apollographql.com/".to_string(),
1103                sub: "apollo".to_string(),
1104                aud: OneOrMany::One(Audience::SelfHosted),
1105                warn_at: UNIX_EPOCH + Duration::from_secs(1676808000),
1106                halt_at: UNIX_EPOCH + Duration::from_secs(1678017600),
1107                tps: Default::default(),
1108                allowed_features: Default::default()
1109            }),
1110        );
1111    }
1112
1113    #[test]
1114    fn test_license_parse_with_whitespace() {
1115        let license = License::from_str("   eyJhbGciOiJFZERTQSJ9.eyJpc3MiOiJodHRwczovL3d3dy5hcG9sbG9ncmFwaHFsLmNvbS8iLCJzdWIiOiJhcG9sbG8iLCJhdWQiOiJTRUxGX0hPU1RFRCIsIndhcm5BdCI6MTY3NjgwODAwMCwiaGFsdEF0IjoxNjc4MDE3NjAwfQ.tXexfjZ2SQeqSwkWQ7zD4XBoxS_Hc5x7tSNJ3ln-BCL_GH7i3U9hsIgdRQTczCAjA_jjk34w39DeSV0nTc5WBw\n ").expect("must be able to decode JWT"); // gitleaks:allow
1116        assert_eq!(
1117            license.claims,
1118            Some(Claims {
1119                iss: "https://www.apollographql.com/".to_string(),
1120                sub: "apollo".to_string(),
1121                aud: OneOrMany::One(Audience::SelfHosted),
1122                warn_at: UNIX_EPOCH + Duration::from_secs(1676808000),
1123                halt_at: UNIX_EPOCH + Duration::from_secs(1678017600),
1124                tps: Default::default(),
1125                allowed_features: Default::default()
1126            }),
1127        );
1128    }
1129
1130    #[test]
1131    fn test_license_parse_fail() {
1132        License::from_str("invalid").expect_err("jwt must fail parse");
1133    }
1134
1135    #[test]
1136    fn claims_serde() {
1137        serde_json::from_value::<Claims>(json!({
1138            "iss": "Issuer",
1139            "sub": "Subject",
1140            "aud": "CLOUD",
1141            "warnAt": 122,
1142            "haltAt": 123,
1143        }))
1144        .expect("json must deserialize");
1145
1146        serde_json::from_value::<Claims>(json!({
1147            "iss": "Issuer",
1148            "sub": "Subject",
1149            "aud": ["CLOUD", "SELF_HOSTED"],
1150            "warnAt": 122,
1151            "haltAt": 123,
1152        }))
1153        .expect("json must deserialize");
1154
1155        serde_json::from_value::<Claims>(json!({
1156            "iss": "Issuer",
1157            "sub": "Subject",
1158            "aud": "OFFLINE",
1159            "warnAt": 122,
1160            "haltAt": 123,
1161        }))
1162        .expect("json must deserialize");
1163
1164        serde_json::from_value::<Claims>(json!({
1165            "iss": "Issuer",
1166            "sub": "Subject",
1167            "aud": "OFFLINE",
1168            "warnAt": 122,
1169            "haltAt": 123,
1170            "allowedFeatures": ["SUBSCRIPTIONS", "ENTITY_CACHING"]
1171        }))
1172        .expect("json must deserialize");
1173    }
1174
1175    #[test]
1176    fn progressive_override_available_to_oss() {
1177        let report = check(
1178            include_str!("testdata/oss.router.yaml"),
1179            include_str!("testdata/progressive_override.graphql"),
1180            LicenseState::default(),
1181        );
1182
1183        // progressive override is available for oss
1184        assert!(
1185            report.restricted_schema_in_use.is_empty(),
1186            "shouldn't have found restricted features"
1187        );
1188    }
1189
1190    #[test]
1191    fn set_context() {
1192        let report = check(
1193            include_str!("testdata/oss.router.yaml"),
1194            include_str!("testdata/set_context.graphql"),
1195            LicenseState::default(),
1196        );
1197
1198        assert!(
1199            report.restricted_schema_in_use.is_empty(),
1200            "shouldn't have found restricted features"
1201        );
1202    }
1203
1204    #[test]
1205    fn progressive_override_with_renamed_join_spec() {
1206        let report = check(
1207            include_str!("testdata/oss.router.yaml"),
1208            include_str!("testdata/progressive_override_renamed_join.graphql"),
1209            LicenseState::default(),
1210        );
1211
1212        assert!(
1213            report.restricted_schema_in_use.is_empty(),
1214            "shouldn't have found restricted features"
1215        );
1216    }
1217
1218    #[test]
1219    fn schema_enforcement_spec_version_in_range() {
1220        let report = check(
1221            include_str!("testdata/oss.router.yaml"),
1222            include_str!("testdata/schema_enforcement_spec_version_in_range.graphql"),
1223            LicenseState::default(),
1224        );
1225
1226        assert!(
1227            !report.restricted_schema_in_use.is_empty(),
1228            "should have found restricted features"
1229        );
1230        assert_snapshot!(report.to_string());
1231    }
1232
1233    #[test]
1234    fn schema_enforcement_spec_version_out_of_range() {
1235        let report = check(
1236            include_str!("testdata/oss.router.yaml"),
1237            include_str!("testdata/schema_enforcement_spec_version_out_of_range.graphql"),
1238            LicenseState::default(),
1239        );
1240
1241        assert!(
1242            report.restricted_schema_in_use.is_empty(),
1243            "shouldn't have found restricted features"
1244        );
1245    }
1246
1247    #[test]
1248    fn schema_enforcement_directive_arg_version_in_range() {
1249        let report = check(
1250            include_str!("testdata/oss.router.yaml"),
1251            include_str!("testdata/schema_enforcement_directive_arg_version_in_range.graphql"),
1252            LicenseState::default(),
1253        );
1254
1255        assert!(
1256            report.restricted_schema_in_use.is_empty(),
1257            "shouldn't have found restricted features"
1258        );
1259    }
1260
1261    #[test]
1262    fn schema_enforcement_directive_arg_version_out_of_range() {
1263        let report = check(
1264            include_str!("testdata/oss.router.yaml"),
1265            include_str!("testdata/schema_enforcement_directive_arg_version_out_of_range.graphql"),
1266            LicenseState::default(),
1267        );
1268
1269        assert!(
1270            report.restricted_schema_in_use.is_empty(),
1271            "shouldn't have found restricted features"
1272        );
1273    }
1274
1275    #[test]
1276    fn schema_enforcement_connectors() {
1277        let report = check(
1278            include_str!("testdata/oss.router.yaml"),
1279            include_str!("testdata/schema_enforcement_connectors.graphql"),
1280            LicenseState::default(),
1281        );
1282
1283        assert!(
1284            report.restricted_schema_in_use.is_empty(),
1285            "shouldn't have found restricted connect feature"
1286        );
1287    }
1288}