apollo_router/plugins/authentication/
mod.rs

1//! Authentication plugin
2
3use std::collections::HashMap;
4use std::ops::ControlFlow;
5use std::str::FromStr;
6use std::sync::Arc;
7use std::time::Duration;
8
9use error::AuthenticationError;
10use error::Error;
11use http::HeaderName;
12use http::HeaderValue;
13use http::StatusCode;
14use http::header;
15use jsonwebtoken::Algorithm;
16use jsonwebtoken::decode_header;
17use once_cell::sync::Lazy;
18use reqwest::Client;
19use schemars::JsonSchema;
20use serde::Deserialize;
21use serde::Serialize;
22use tower::BoxError;
23use tower::ServiceBuilder;
24use tower::ServiceExt;
25use url::Url;
26
27use self::jwks::JwksManager;
28use self::subgraph::SigningParams;
29use self::subgraph::SigningParamsConfig;
30use self::subgraph::SubgraphAuth;
31use crate::graphql;
32use crate::layers::ServiceBuilderExt;
33use crate::plugin::PluginInit;
34use crate::plugin::PluginPrivate;
35use crate::plugin::serde::deserialize_header_name;
36use crate::plugin::serde::deserialize_header_value;
37use crate::plugins::authentication::connector::ConnectorAuth;
38use crate::plugins::authentication::error::ErrorContext;
39use crate::plugins::authentication::jwks::Audiences;
40use crate::plugins::authentication::jwks::Issuers;
41use crate::plugins::authentication::jwks::JwksConfig;
42use crate::plugins::authentication::subgraph::make_signing_params;
43use crate::services::APPLICATION_JSON_HEADER_VALUE;
44use crate::services::connector_service::ConnectorSourceRef;
45use crate::services::router;
46
47pub(crate) mod jwks;
48
49pub(crate) mod connector;
50
51pub(crate) mod subgraph;
52
53mod error;
54#[cfg(test)]
55mod tests;
56
57pub(crate) const AUTHENTICATION_SPAN_NAME: &str = "authentication_plugin";
58pub(crate) const APOLLO_AUTHENTICATION_JWT_CLAIMS: &str = "apollo::authentication::jwt_claims";
59const HEADER_TOKEN_TRUNCATED: &str = "(truncated)";
60
61const DEFAULT_AUTHENTICATION_NETWORK_TIMEOUT: Duration = Duration::from_secs(15);
62const DEFAULT_AUTHENTICATION_DOWNLOAD_INTERVAL: Duration = Duration::from_secs(60);
63
64static CLIENT: Lazy<Result<Client, BoxError>> = Lazy::new(|| Ok(Client::new()));
65
66struct Router {
67    configuration: JWTConf,
68    jwks_manager: JwksManager,
69}
70
71struct AuthenticationPlugin {
72    router: Option<Router>,
73    subgraph: Option<SubgraphAuth>,
74    connector: Option<ConnectorAuth>,
75}
76
77#[derive(Clone, Debug, Deserialize, JsonSchema, PartialEq)]
78enum OnError {
79    Continue,
80    Error,
81}
82
83impl Default for OnError {
84    fn default() -> Self {
85        Self::Error
86    }
87}
88
89#[derive(Clone, Debug, Deserialize, JsonSchema, serde_derive_default::Default)]
90#[serde(deny_unknown_fields)]
91struct JWTConf {
92    /// List of JWKS used to verify tokens
93    jwks: Vec<JwksConf>,
94    /// HTTP header expected to contain JWT
95    #[serde(default = "default_header_name")]
96    header_name: String,
97    /// Header value prefix
98    #[serde(default = "default_header_value_prefix")]
99    header_value_prefix: String,
100    /// Whether to ignore any mismatched prefixes
101    #[serde(default)]
102    ignore_other_prefixes: bool,
103    /// Alternative sources to extract the JWT
104    #[serde(default)]
105    sources: Vec<Source>,
106    /// Control the behavior when an error occurs during the authentication process.
107    ///
108    /// Defaults to `Error`. When set to `Continue`, requests that fail JWT authentication will
109    /// continue to be processed by the router, but without the JWT claims in the context. When set
110    /// to `Error`, requests that fail JWT authentication will be rejected with a HTTP 403 error.
111    #[serde(default)]
112    on_error: OnError,
113}
114
115#[derive(Clone, Debug, Deserialize, JsonSchema)]
116#[serde(deny_unknown_fields)]
117struct JwksConf {
118    /// Retrieve the JWK Set
119    url: String,
120    /// Polling interval for each JWKS endpoint in human-readable format; defaults to 60s
121    #[serde(
122        deserialize_with = "humantime_serde::deserialize",
123        default = "default_poll_interval"
124    )]
125    #[schemars(with = "String", default = "default_poll_interval")]
126    poll_interval: Duration,
127    /// Expected issuers for tokens verified by that JWKS
128    ///
129    /// If not specified, the issuer will not be checked.
130    issuers: Option<Issuers>,
131    /// Expected audiences for tokens verified by that JWKS
132    ///
133    /// If not specified, the audience will not be checked.
134    audiences: Option<Audiences>,
135    /// List of accepted algorithms. Possible values are `HS256`, `HS384`, `HS512`, `ES256`, `ES384`, `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`, `EdDSA`
136    #[schemars(with = "Option<Vec<String>>", default)]
137    #[serde(default)]
138    algorithms: Option<Vec<Algorithm>>,
139    /// List of headers to add to the JWKS request
140    #[serde(default)]
141    headers: Vec<Header>,
142}
143
144#[derive(Clone, Debug, JsonSchema, Deserialize)]
145#[serde(rename_all = "snake_case", deny_unknown_fields)]
146/// Insert a header
147struct Header {
148    /// The name of the header
149    #[schemars(with = "String")]
150    #[serde(deserialize_with = "deserialize_header_name")]
151    name: HeaderName,
152
153    /// The value for the header
154    #[schemars(with = "String")]
155    #[serde(deserialize_with = "deserialize_header_value")]
156    value: HeaderValue,
157}
158
159#[derive(Clone, Debug, Deserialize, JsonSchema)]
160#[serde(deny_unknown_fields, rename_all = "lowercase", tag = "type")]
161enum Source {
162    Header {
163        /// HTTP header expected to contain JWT
164        #[serde(default = "default_header_name")]
165        name: String,
166        /// Header value prefix
167        #[serde(default = "default_header_value_prefix")]
168        value_prefix: String,
169    },
170    Cookie {
171        /// Name of the cookie containing the JWT
172        name: String,
173    },
174}
175
176/// Authentication
177#[derive(Clone, Debug, Default, Deserialize, JsonSchema)]
178#[serde(deny_unknown_fields)]
179#[schemars(rename = "AuthenticationConfig")]
180struct Conf {
181    /// Router configuration
182    router: Option<RouterConf>,
183    /// Subgraph configuration
184    subgraph: Option<subgraph::Config>,
185    /// Connector configuration
186    connector: Option<connector::Config>,
187}
188
189// We may support additional authentication mechanisms in future, so all
190// configuration (which is currently JWT specific) is isolated to the
191// JWTConf structure.
192#[derive(Clone, Debug, Default, Deserialize, JsonSchema)]
193#[serde(deny_unknown_fields)]
194#[schemars(rename = "AuthenticationRouterConfig")]
195struct RouterConf {
196    /// The JWT configuration
197    jwt: JWTConf,
198}
199
200fn default_header_name() -> String {
201    header::AUTHORIZATION.to_string()
202}
203
/// Default header value prefix: `Bearer`.
fn default_header_value_prefix() -> String {
    String::from("Bearer")
}
207
208fn default_poll_interval() -> Duration {
209    DEFAULT_AUTHENTICATION_DOWNLOAD_INTERVAL
210}
211
212#[async_trait::async_trait]
213impl PluginPrivate for AuthenticationPlugin {
214    type Config = Conf;
215
216    async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
217        let subgraph = Self::init_subgraph(&init).await?;
218        let router = Self::init_router(&init).await?;
219        let connector = Self::init_connector(init).await?;
220
221        Ok(Self {
222            router,
223            subgraph,
224            connector,
225        })
226    }
227
228    fn router_service(&self, service: router::BoxService) -> router::BoxService {
229        // Return without layering if no router config was defined
230        let Some(router_config) = &self.router else {
231            return service;
232        };
233
234        fn authentication_service_span() -> impl Fn(&router::Request) -> tracing::Span + Clone {
235            move |_request: &router::Request| {
236                tracing::info_span!(
237                    AUTHENTICATION_SPAN_NAME,
238                    "authentication service" = stringify!(router::Request),
239                    "otel.kind" = "INTERNAL"
240                )
241            }
242        }
243
244        let jwks_manager = router_config.jwks_manager.clone();
245        let configuration = router_config.configuration.clone();
246
247        ServiceBuilder::new()
248            .instrument(authentication_service_span())
249            .checkpoint(move |request: router::Request| {
250                Ok(authenticate(&configuration, &jwks_manager, request))
251            })
252            .service(service)
253            .boxed()
254    }
255
256    fn subgraph_service(
257        &self,
258        name: &str,
259        service: crate::services::subgraph::BoxService,
260    ) -> crate::services::subgraph::BoxService {
261        // Return without layering if no subgraph config was defined
262        let Some(subgraph) = &self.subgraph else {
263            return service;
264        };
265
266        subgraph.subgraph_service(name, service)
267    }
268
269    fn connector_request_service(
270        &self,
271        service: crate::services::connector::request_service::BoxService,
272        _: String,
273    ) -> crate::services::connector::request_service::BoxService {
274        // Return without layering if no connector config was defined
275        let Some(connector_auth) = &self.connector else {
276            return service;
277        };
278
279        connector_auth.connector_request_service(service)
280    }
281}
282
283impl AuthenticationPlugin {
284    async fn init_subgraph(init: &PluginInit<Conf>) -> Result<Option<SubgraphAuth>, BoxError> {
285        // if no subgraph config was defined, then return early
286        let Some(subgraph_conf) = init.config.subgraph.clone() else {
287            return Ok(None);
288        };
289
290        let all = if let Some(config) = &subgraph_conf.all {
291            Some(Arc::new(make_signing_params(config, "all").await?))
292        } else {
293            None
294        };
295
296        let mut subgraphs: HashMap<String, Arc<SigningParamsConfig>> = Default::default();
297        for (subgraph_name, config) in &subgraph_conf.subgraphs {
298            subgraphs.insert(
299                subgraph_name.clone(),
300                Arc::new(make_signing_params(config, subgraph_name.as_str()).await?),
301            );
302        }
303
304        Ok(Some(SubgraphAuth {
305            signing_params: Arc::new(SigningParams { all, subgraphs }),
306        }))
307    }
308
309    async fn init_router(init: &PluginInit<Conf>) -> Result<Option<Router>, BoxError> {
310        // if no router config was defined, then return early
311        let Some(mut router_conf) = init.config.router.clone() else {
312            return Ok(None);
313        };
314
315        if router_conf
316            .jwt
317            .header_value_prefix
318            .as_bytes()
319            .iter()
320            .any(u8::is_ascii_whitespace)
321        {
322            return Err(Error::BadHeaderValuePrefix.into());
323        }
324
325        for source in &router_conf.jwt.sources {
326            if let Source::Header { value_prefix, .. } = source
327                && value_prefix.as_bytes().iter().any(u8::is_ascii_whitespace)
328            {
329                return Err(Error::BadHeaderValuePrefix.into());
330            }
331        }
332
333        router_conf.jwt.sources.insert(
334            0,
335            Source::Header {
336                name: router_conf.jwt.header_name.clone(),
337                value_prefix: router_conf.jwt.header_value_prefix.clone(),
338            },
339        );
340
341        let mut list = vec![];
342        for jwks_conf in &router_conf.jwt.jwks {
343            let url: Url = Url::from_str(jwks_conf.url.as_str())?;
344            list.push(JwksConfig {
345                url,
346                issuers: jwks_conf.issuers.clone(),
347                audiences: jwks_conf.audiences.clone(),
348                algorithms: jwks_conf
349                    .algorithms
350                    .as_ref()
351                    .map(|algs| algs.iter().cloned().collect()),
352                poll_interval: jwks_conf.poll_interval,
353                headers: jwks_conf.headers.clone(),
354            });
355        }
356
357        let jwks_manager = JwksManager::new(list).await?;
358
359        Ok(Some(Router {
360            configuration: router_conf.jwt,
361            jwks_manager,
362        }))
363    }
364
365    async fn init_connector(init: PluginInit<Conf>) -> Result<Option<ConnectorAuth>, BoxError> {
366        // if no connector config was defined, then return early
367        let Some(connector_conf) = init.config.connector.clone() else {
368            return Ok(None);
369        };
370
371        let mut signing_params: HashMap<ConnectorSourceRef, Arc<SigningParamsConfig>> =
372            Default::default();
373        for (s, source_config) in connector_conf.sources {
374            let source_ref: ConnectorSourceRef = s.parse()?;
375            signing_params.insert(
376                source_ref.clone(),
377                make_signing_params(&source_config, &source_ref.subgraph_name)
378                    .await
379                    .map(Arc::new)?,
380            );
381        }
382
383        Ok(Some(ConnectorAuth {
384            signing_params: Arc::new(signing_params),
385        }))
386    }
387}
388
389#[derive(Debug, Serialize, Deserialize)]
390enum JwtStatus {
391    Failure {
392        r#type: String,
393        name: String,
394        error: ErrorContext,
395    },
396    Success {
397        r#type: String,
398        name: String,
399    },
400}
401
402impl JwtStatus {
403    fn new_failure(source: Option<&Source>, error_context: ErrorContext) -> Self {
404        let (r#type, name) = match source {
405            Some(Source::Header { name, .. }) => ("header", name.as_str()),
406            Some(Source::Cookie { name }) => ("cookie", name.as_str()),
407            None => ("unknown", "unknown"),
408        };
409
410        Self::Failure {
411            r#type: r#type.into(),
412            name: name.into(),
413            error: error_context,
414        }
415    }
416
417    fn new_success(source: Option<&Source>) -> Self {
418        match source {
419            Some(Source::Header { name, .. }) => Self::Success {
420                r#type: "header".into(),
421                name: name.into(),
422            },
423            Some(Source::Cookie { name }) => Self::Success {
424                r#type: "cookie".into(),
425                name: name.into(),
426            },
427            None => Self::Success {
428                r#type: "unknown".into(),
429                name: "unknown".into(),
430            },
431        }
432    }
433
434    #[cfg(test)]
435    /// Returns the error context if the status is a failure; Otherwise, returns None.
436    fn error(&self) -> Option<&ErrorContext> {
437        match self {
438            Self::Failure { error, .. } => Some(error),
439            _ => None,
440        }
441    }
442}
443
444const JWT_CONTEXT_KEY: &str = "apollo::authentication::jwt_status";
445
446fn authenticate(
447    config: &JWTConf,
448    jwks_manager: &JwksManager,
449    request: router::Request,
450) -> ControlFlow<router::Response, router::Request> {
451    // We are going to do a lot of similar checking so let's define a local function
452    // to help reduce repetition
453    fn failure_message(
454        request: router::Request,
455        config: &JWTConf,
456        error: AuthenticationError,
457        status: StatusCode,
458        source: Option<&Source>,
459    ) -> ControlFlow<router::Response, router::Request> {
460        // This is a metric and will not appear in the logs
461        let failed = true;
462        increment_jwt_counter_metric(failed);
463
464        tracing::info!(message = %error, "jwt authentication failure");
465
466        let _ = request.context.insert_json_value(
467            JWT_CONTEXT_KEY,
468            serde_json_bytes::json!(JwtStatus::new_failure(source, error.as_context_object())),
469        );
470
471        if config.on_error == OnError::Error {
472            let response = router::Response::infallible_builder()
473                .error(
474                    graphql::Error::builder()
475                        .message(error.to_string())
476                        .extension_code("AUTH_ERROR")
477                        .build(),
478                )
479                .status_code(status)
480                .header(header::CONTENT_TYPE, APPLICATION_JSON_HEADER_VALUE.clone())
481                .context(request.context)
482                .build();
483
484            ControlFlow::Break(response)
485        } else {
486            ControlFlow::Continue(request)
487        }
488    }
489
490    /// This is the documented metric
491    fn increment_jwt_counter_metric(failed: bool) {
492        u64_counter!(
493            "apollo.router.operations.authentication.jwt",
494            "Number of requests with JWT authentication",
495            1,
496            authentication.jwt.failed = failed
497        );
498    }
499
500    let mut jwt = None;
501    let mut source_of_extracted_jwt = None;
502    for source in &config.sources {
503        let extracted_jwt = jwks::extract_jwt(
504            source,
505            config.ignore_other_prefixes,
506            request.router_request.headers(),
507        );
508
509        match extracted_jwt {
510            None => continue,
511            Some(Ok(extracted_jwt)) => {
512                source_of_extracted_jwt = Some(source);
513                jwt = Some(extracted_jwt);
514                break;
515            }
516            Some(Err(error)) => {
517                return failure_message(
518                    request,
519                    config,
520                    error,
521                    StatusCode::BAD_REQUEST,
522                    Some(source),
523                );
524            }
525        }
526    }
527
528    let jwt = match jwt {
529        Some(jwt) => jwt,
530        None => return ControlFlow::Continue(request),
531    };
532
533    // Try to create a valid header to work with
534    let jwt_header = match decode_header(jwt) {
535        Ok(h) => h,
536        Err(e) => {
537            // Don't reflect the jwt on error, just reply with a fixed
538            // error message.
539            return failure_message(
540                request,
541                config,
542                AuthenticationError::InvalidHeader(HEADER_TOKEN_TRUNCATED.to_owned(), e),
543                StatusCode::BAD_REQUEST,
544                source_of_extracted_jwt,
545            );
546        }
547    };
548
549    // Extract our search criteria from our jwt
550    let criteria = jwks::JWTCriteria {
551        kid: jwt_header.kid,
552        alg: jwt_header.alg,
553    };
554
555    // Search our list of JWKS to find the kid and process it
556    // Note: This will search through JWKS in the order in which they are defined
557    // in configuration.
558    if let Some(keys) = jwks::search_jwks(jwks_manager, &criteria) {
559        let (issuers, audiences, token_data) = match jwks::decode_jwt(jwt, keys, criteria) {
560            Ok(data) => data,
561            Err((auth_error, status_code)) => {
562                return failure_message(
563                    request,
564                    config,
565                    auth_error,
566                    status_code,
567                    source_of_extracted_jwt,
568                );
569            }
570        };
571
572        if let Some(configured_issuers) = issuers
573            && let Some(token_issuer) = token_data
574                .claims
575                .as_object()
576                .and_then(|o| o.get("iss"))
577                .and_then(|value| value.as_str())
578            && !configured_issuers.contains(token_issuer)
579        {
580            let mut issuers_for_error: Vec<String> = configured_issuers.into_iter().collect();
581            issuers_for_error.sort(); // done to maintain consistent ordering in error message
582            return failure_message(
583                request,
584                config,
585                AuthenticationError::InvalidIssuer {
586                    expected: issuers_for_error
587                        .iter()
588                        .map(|issuer| issuer.to_string())
589                        .collect::<Vec<_>>()
590                        .join(", "),
591                    token: token_issuer.to_string(),
592                },
593                StatusCode::INTERNAL_SERVER_ERROR,
594                source_of_extracted_jwt,
595            );
596        }
597
598        if let Some(configured_audiences) = audiences {
599            let maybe_token_audiences = token_data.claims.as_object().and_then(|o| o.get("aud"));
600            let Some(maybe_token_audiences) = maybe_token_audiences else {
601                let mut audiences_for_error: Vec<String> =
602                    configured_audiences.into_iter().collect();
603                audiences_for_error.sort(); // done to maintain consistent ordering in error message
604                return failure_message(
605                    request,
606                    config,
607                    AuthenticationError::InvalidAudience {
608                        expected: audiences_for_error
609                            .iter()
610                            .map(|audience| audience.to_string())
611                            .collect::<Vec<_>>()
612                            .join(", "),
613                        actual: "<none>".to_string(),
614                    },
615                    StatusCode::UNAUTHORIZED,
616                    source_of_extracted_jwt,
617                );
618            };
619
620            if let Some(token_audience) = maybe_token_audiences.as_str() {
621                if !configured_audiences.contains(token_audience) {
622                    let mut audiences_for_error: Vec<String> =
623                        configured_audiences.into_iter().collect();
624                    audiences_for_error.sort(); // done to maintain consistent ordering in error message
625                    return failure_message(
626                        request,
627                        config,
628                        AuthenticationError::InvalidAudience {
629                            expected: audiences_for_error
630                                .iter()
631                                .map(|audience| audience.to_string())
632                                .collect::<Vec<_>>()
633                                .join(", "),
634                            actual: token_audience.to_string(),
635                        },
636                        StatusCode::UNAUTHORIZED,
637                        source_of_extracted_jwt,
638                    );
639                }
640            } else {
641                // If the token has incorrectly configured audiences, we cannot validate it against
642                // the configured audiences.
643                let mut audiences_for_error: Vec<String> =
644                    configured_audiences.into_iter().collect();
645                audiences_for_error.sort(); // done to maintain consistent ordering in error message
646                return failure_message(
647                    request,
648                    config,
649                    AuthenticationError::InvalidAudience {
650                        expected: audiences_for_error
651                            .iter()
652                            .map(|audience| audience.to_string())
653                            .collect::<Vec<_>>()
654                            .join(", "),
655                        actual: maybe_token_audiences.to_string(),
656                    },
657                    StatusCode::UNAUTHORIZED,
658                    source_of_extracted_jwt,
659                );
660            }
661        }
662
663        if let Err(e) = request
664            .context
665            .insert(APOLLO_AUTHENTICATION_JWT_CLAIMS, token_data.claims.clone())
666        {
667            return failure_message(
668                request,
669                config,
670                AuthenticationError::CannotInsertClaimsIntoContext(e),
671                StatusCode::INTERNAL_SERVER_ERROR,
672                source_of_extracted_jwt,
673            );
674        }
675        // This is a metric and will not appear in the logs
676        //
677        // Apparently intended to be `apollo.router.operations.authentication.jwt` like above,
678        // but has existed for two years with a buggy name. Keep it for now.
679        u64_counter!(
680            "apollo.router.operations.jwt",
681            "Number of requests with JWT successful authentication (deprecated, \
682                use `apollo.router.operations.authentication.jwt` \
683                with `authentication.jwt.failed = false` instead)",
684            1
685        );
686        // Use the fixed name too:
687        let failed = false;
688        increment_jwt_counter_metric(failed);
689
690        let _ = request.context.insert_json_value(
691            JWT_CONTEXT_KEY,
692            serde_json_bytes::json!(JwtStatus::new_success(source_of_extracted_jwt)),
693        );
694
695        return ControlFlow::Continue(request);
696    }
697
698    // We can't find a key to process this JWT.
699    let err = criteria.kid.map_or_else(
700        || AuthenticationError::CannotFindSuitableKey(criteria.alg, None),
701        AuthenticationError::CannotFindKID,
702    );
703
704    failure_message(
705        request,
706        config,
707        err,
708        StatusCode::UNAUTHORIZED,
709        source_of_extracted_jwt,
710    )
711}
712
713// This macro allows us to use it in our plugin registry!
714// register_plugin takes a group name, and a plugin name.
715//
716// In order to keep the plugin names consistent,
717// we use using the `Reverse domain name notation`
718register_private_plugin!("apollo", "authentication", AuthenticationPlugin);