1use std::collections::HashMap;
4use std::ops::ControlFlow;
5use std::str::FromStr;
6use std::sync::Arc;
7use std::time::Duration;
8
9use error::AuthenticationError;
10use error::Error;
11use http::HeaderName;
12use http::HeaderValue;
13use http::StatusCode;
14use http::header;
15use jsonwebtoken::Algorithm;
16use jsonwebtoken::decode_header;
17use once_cell::sync::Lazy;
18use reqwest::Client;
19use schemars::JsonSchema;
20use serde::Deserialize;
21use serde::Serialize;
22use tower::BoxError;
23use tower::ServiceBuilder;
24use tower::ServiceExt;
25use url::Url;
26
27use self::jwks::JwksManager;
28use self::subgraph::SigningParams;
29use self::subgraph::SigningParamsConfig;
30use self::subgraph::SubgraphAuth;
31use crate::graphql;
32use crate::layers::ServiceBuilderExt;
33use crate::plugin::PluginInit;
34use crate::plugin::PluginPrivate;
35use crate::plugin::serde::deserialize_header_name;
36use crate::plugin::serde::deserialize_header_value;
37use crate::plugins::authentication::connector::ConnectorAuth;
38use crate::plugins::authentication::error::ErrorContext;
39use crate::plugins::authentication::jwks::Audiences;
40use crate::plugins::authentication::jwks::Issuers;
41use crate::plugins::authentication::jwks::JwksConfig;
42use crate::plugins::authentication::subgraph::make_signing_params;
43use crate::services::APPLICATION_JSON_HEADER_VALUE;
44use crate::services::connector_service::ConnectorSourceRef;
45use crate::services::router;
46
47pub(crate) mod jwks;
48
49pub(crate) mod connector;
50
51pub(crate) mod subgraph;
52
53mod error;
54#[cfg(test)]
55mod tests;
56
/// Name of the tracing span that wraps the router-level authentication checkpoint.
pub(crate) const AUTHENTICATION_SPAN_NAME: &str = "authentication_plugin";
/// Request-context key under which the verified JWT claims are stored after successful
/// router authentication.
pub(crate) const APOLLO_AUTHENTICATION_JWT_CLAIMS: &str = "apollo::authentication::jwt_claims";
// Placeholder passed to `AuthenticationError::InvalidHeader` instead of the raw token when the
// JWT header cannot be decoded (presumably so the token is not echoed back in errors).
const HEADER_TOKEN_TRUNCATED: &str = "(truncated)";

// Not referenced in this part of the module; presumably the network timeout used when
// fetching JWKS (NOTE(review): confirm in `jwks`).
const DEFAULT_AUTHENTICATION_NETWORK_TIMEOUT: Duration = Duration::from_secs(15);
// Default JWKS refresh period, used by `default_poll_interval`.
const DEFAULT_AUTHENTICATION_DOWNLOAD_INTERVAL: Duration = Duration::from_secs(60);

// Lazily constructed, process-wide HTTP client. Not used in this part of the module;
// presumably shared by the JWKS downloader (NOTE(review): confirm in `jwks`).
static CLIENT: Lazy<Result<Client, BoxError>> = Lazy::new(|| Ok(Client::new()));
65
/// Router-level JWT state built by `AuthenticationPlugin::init_router`.
struct Router {
    /// The `router.jwt` configuration, with the primary header source already prepended
    /// to `sources`.
    configuration: JWTConf,
    /// Manager for the configured JWKS endpoints, used to look up verification keys.
    jwks_manager: JwksManager,
}
70
/// The `authentication` plugin.
///
/// Each part is optional; when a part is `None` the corresponding service hook returns the
/// service unchanged.
struct AuthenticationPlugin {
    /// JWT authentication of incoming router requests.
    router: Option<Router>,
    /// Signing of outgoing subgraph requests.
    subgraph: Option<SubgraphAuth>,
    /// Signing of outgoing connector requests.
    connector: Option<ConnectorAuth>,
}
76
77#[derive(Clone, Debug, Deserialize, JsonSchema, PartialEq)]
78enum OnError {
79 Continue,
80 Error,
81}
82
83impl Default for OnError {
84 fn default() -> Self {
85 Self::Error
86 }
87}
88
// JWT authentication settings for router requests (`router.jwt`).
//
// `AuthenticationPlugin::init_router` rejects header prefixes containing ASCII whitespace and
// prepends a header source built from `header_name`/`header_value_prefix` to `sources`.
// (`//` rather than `///` so the generated JSON schema is unaffected.)
#[derive(Clone, Debug, Deserialize, JsonSchema, serde_derive_default::Default)]
#[serde(deny_unknown_fields)]
struct JWTConf {
    // JWKS endpoints whose keys may verify tokens; each becomes a `JwksConfig`.
    jwks: Vec<JwksConf>,
    // Header the token is read from first; defaults to `authorization`.
    #[serde(default = "default_header_name")]
    header_name: String,
    // Prefix expected before the token in `header_name`; defaults to `Bearer`.
    // Must not contain ASCII whitespace (plugin creation fails otherwise).
    #[serde(default = "default_header_value_prefix")]
    header_value_prefix: String,
    // Forwarded to `jwks::extract_jwt` for every source; see there for its exact effect.
    #[serde(default)]
    ignore_other_prefixes: bool,
    // Additional token locations, tried in order after the primary header; the first
    // source that yields a token is used.
    #[serde(default)]
    sources: Vec<Source>,
    // Whether a failed authentication rejects the request (`Error`, default) or lets it
    // continue (`Continue`).
    #[serde(default)]
    on_error: OnError,
}
114
// One entry of `router.jwt.jwks`; `init_router` converts it into a `JwksConfig`.
// (`//` rather than `///` so the generated JSON schema is unaffected.)
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
struct JwksConf {
    // Location of the key set; parsed with `Url::from_str` at start-up, so an invalid URL
    // fails plugin creation.
    url: String,
    // Refresh period in humantime notation (e.g. "60s"); defaults to
    // `DEFAULT_AUTHENTICATION_DOWNLOAD_INTERVAL` (60 seconds).
    #[serde(
        deserialize_with = "humantime_serde::deserialize",
        default = "default_poll_interval"
    )]
    #[schemars(with = "String", default = "default_poll_interval")]
    poll_interval: Duration,
    // Accepted issuers. `authenticate` compares the issuers returned by `jwks::decode_jwt`
    // (presumably those of the JWKS entry whose key verified the token) with the `iss` claim.
    issuers: Option<Issuers>,
    // Accepted audiences, compared in `authenticate` with the token's `aud` claim in the
    // same way as `issuers`.
    audiences: Option<Audiences>,
    // Optional allow-list of signing algorithms, forwarded to `JwksConfig::algorithms`.
    #[schemars(with = "Option<Vec<String>>", default)]
    #[serde(default)]
    algorithms: Option<Vec<Algorithm>>,
    // Extra headers forwarded to `JwksConfig::headers`; presumably sent with the JWKS
    // download request (NOTE(review): confirm in `jwks`).
    #[serde(default)]
    headers: Vec<Header>,
}
143
// A static HTTP header (name/value pair) configured on a JWKS entry.
// (`//` rather than `///` so the generated JSON schema is unaffected.)
#[derive(Clone, Debug, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
struct Header {
    // Header name, validated while deserializing.
    #[schemars(with = "String")]
    #[serde(deserialize_with = "deserialize_header_name")]
    name: HeaderName,

    // Header value, validated while deserializing.
    #[schemars(with = "String")]
    #[serde(deserialize_with = "deserialize_header_value")]
    value: HeaderValue,
}
158
// A place to look for a JWT, selected in configuration by `type: header` or `type: cookie`.
// (`//` rather than `///` so the generated JSON schema is unaffected.)
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, rename_all = "lowercase", tag = "type")]
enum Source {
    // Token carried in a request header, after `value_prefix`.
    // `value_prefix` must not contain ASCII whitespace (checked in `init_router`).
    Header {
        #[serde(default = "default_header_name")]
        name: String,
        #[serde(default = "default_header_value_prefix")]
        value_prefix: String,
    },
    // Token carried in the cookie called `name`.
    Cookie {
        name: String,
    },
}
175
// Top-level configuration of the authentication plugin. Every section is optional and an
// absent section disables that part of the plugin.
// (`//` rather than `///` so the generated JSON schema is unaffected.)
#[derive(Clone, Debug, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[schemars(rename = "AuthenticationConfig")]
struct Conf {
    // JWT authentication of incoming requests.
    router: Option<RouterConf>,
    // Signing of outgoing subgraph requests.
    subgraph: Option<subgraph::Config>,
    // Signing of outgoing connector requests.
    connector: Option<connector::Config>,
}
188
// The `router` section of the plugin configuration.
// (`//` rather than `///` so the generated JSON schema is unaffected.)
#[derive(Clone, Debug, Default, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
#[schemars(rename = "AuthenticationRouterConfig")]
struct RouterConf {
    // JWT settings.
    jwt: JWTConf,
}
199
200fn default_header_name() -> String {
201 header::AUTHORIZATION.to_string()
202}
203
/// Serde default for the prefix that precedes the token in a header value.
fn default_header_value_prefix() -> String {
    String::from("Bearer")
}
207
/// Serde/schemars default for `JwksConf::poll_interval`: the module-wide JWKS refresh
/// period (`DEFAULT_AUTHENTICATION_DOWNLOAD_INTERVAL`).
fn default_poll_interval() -> Duration {
    DEFAULT_AUTHENTICATION_DOWNLOAD_INTERVAL
}
211
212#[async_trait::async_trait]
213impl PluginPrivate for AuthenticationPlugin {
214 type Config = Conf;
215
216 async fn new(init: PluginInit<Self::Config>) -> Result<Self, BoxError> {
217 let subgraph = Self::init_subgraph(&init).await?;
218 let router = Self::init_router(&init).await?;
219 let connector = Self::init_connector(init).await?;
220
221 Ok(Self {
222 router,
223 subgraph,
224 connector,
225 })
226 }
227
228 fn router_service(&self, service: router::BoxService) -> router::BoxService {
229 let Some(router_config) = &self.router else {
231 return service;
232 };
233
234 fn authentication_service_span() -> impl Fn(&router::Request) -> tracing::Span + Clone {
235 move |_request: &router::Request| {
236 tracing::info_span!(
237 AUTHENTICATION_SPAN_NAME,
238 "authentication service" = stringify!(router::Request),
239 "otel.kind" = "INTERNAL"
240 )
241 }
242 }
243
244 let jwks_manager = router_config.jwks_manager.clone();
245 let configuration = router_config.configuration.clone();
246
247 ServiceBuilder::new()
248 .instrument(authentication_service_span())
249 .checkpoint(move |request: router::Request| {
250 Ok(authenticate(&configuration, &jwks_manager, request))
251 })
252 .service(service)
253 .boxed()
254 }
255
256 fn subgraph_service(
257 &self,
258 name: &str,
259 service: crate::services::subgraph::BoxService,
260 ) -> crate::services::subgraph::BoxService {
261 let Some(subgraph) = &self.subgraph else {
263 return service;
264 };
265
266 subgraph.subgraph_service(name, service)
267 }
268
269 fn connector_request_service(
270 &self,
271 service: crate::services::connector::request_service::BoxService,
272 _: String,
273 ) -> crate::services::connector::request_service::BoxService {
274 let Some(connector_auth) = &self.connector else {
276 return service;
277 };
278
279 connector_auth.connector_request_service(service)
280 }
281}
282
283impl AuthenticationPlugin {
284 async fn init_subgraph(init: &PluginInit<Conf>) -> Result<Option<SubgraphAuth>, BoxError> {
285 let Some(subgraph_conf) = init.config.subgraph.clone() else {
287 return Ok(None);
288 };
289
290 let all = if let Some(config) = &subgraph_conf.all {
291 Some(Arc::new(make_signing_params(config, "all").await?))
292 } else {
293 None
294 };
295
296 let mut subgraphs: HashMap<String, Arc<SigningParamsConfig>> = Default::default();
297 for (subgraph_name, config) in &subgraph_conf.subgraphs {
298 subgraphs.insert(
299 subgraph_name.clone(),
300 Arc::new(make_signing_params(config, subgraph_name.as_str()).await?),
301 );
302 }
303
304 Ok(Some(SubgraphAuth {
305 signing_params: Arc::new(SigningParams { all, subgraphs }),
306 }))
307 }
308
309 async fn init_router(init: &PluginInit<Conf>) -> Result<Option<Router>, BoxError> {
310 let Some(mut router_conf) = init.config.router.clone() else {
312 return Ok(None);
313 };
314
315 if router_conf
316 .jwt
317 .header_value_prefix
318 .as_bytes()
319 .iter()
320 .any(u8::is_ascii_whitespace)
321 {
322 return Err(Error::BadHeaderValuePrefix.into());
323 }
324
325 for source in &router_conf.jwt.sources {
326 if let Source::Header { value_prefix, .. } = source
327 && value_prefix.as_bytes().iter().any(u8::is_ascii_whitespace)
328 {
329 return Err(Error::BadHeaderValuePrefix.into());
330 }
331 }
332
333 router_conf.jwt.sources.insert(
334 0,
335 Source::Header {
336 name: router_conf.jwt.header_name.clone(),
337 value_prefix: router_conf.jwt.header_value_prefix.clone(),
338 },
339 );
340
341 let mut list = vec![];
342 for jwks_conf in &router_conf.jwt.jwks {
343 let url: Url = Url::from_str(jwks_conf.url.as_str())?;
344 list.push(JwksConfig {
345 url,
346 issuers: jwks_conf.issuers.clone(),
347 audiences: jwks_conf.audiences.clone(),
348 algorithms: jwks_conf
349 .algorithms
350 .as_ref()
351 .map(|algs| algs.iter().cloned().collect()),
352 poll_interval: jwks_conf.poll_interval,
353 headers: jwks_conf.headers.clone(),
354 });
355 }
356
357 let jwks_manager = JwksManager::new(list).await?;
358
359 Ok(Some(Router {
360 configuration: router_conf.jwt,
361 jwks_manager,
362 }))
363 }
364
365 async fn init_connector(init: PluginInit<Conf>) -> Result<Option<ConnectorAuth>, BoxError> {
366 let Some(connector_conf) = init.config.connector.clone() else {
368 return Ok(None);
369 };
370
371 let mut signing_params: HashMap<ConnectorSourceRef, Arc<SigningParamsConfig>> =
372 Default::default();
373 for (s, source_config) in connector_conf.sources {
374 let source_ref: ConnectorSourceRef = s.parse()?;
375 signing_params.insert(
376 source_ref.clone(),
377 make_signing_params(&source_config, &source_ref.subgraph_name)
378 .await
379 .map(Arc::new)?,
380 );
381 }
382
383 Ok(Some(ConnectorAuth {
384 signing_params: Arc::new(signing_params),
385 }))
386 }
387}
388
/// Outcome of router JWT authentication.
///
/// `authenticate` serializes it into the request context under `JWT_CONTEXT_KEY` whenever a
/// token was found in one of the configured sources.
#[derive(Debug, Serialize, Deserialize)]
enum JwtStatus {
    /// Authentication failed.
    Failure {
        /// Kind of source the token came from: `"header"`, `"cookie"` or `"unknown"`.
        r#type: String,
        /// Header or cookie name (`"unknown"` when the source is not known).
        name: String,
        /// Structured description of the error.
        error: ErrorContext,
    },
    /// Authentication succeeded.
    Success {
        /// Kind of source the token came from: `"header"`, `"cookie"` or `"unknown"`.
        r#type: String,
        /// Header or cookie name (`"unknown"` when the source is not known).
        name: String,
    },
}
401
402impl JwtStatus {
403 fn new_failure(source: Option<&Source>, error_context: ErrorContext) -> Self {
404 let (r#type, name) = match source {
405 Some(Source::Header { name, .. }) => ("header", name.as_str()),
406 Some(Source::Cookie { name }) => ("cookie", name.as_str()),
407 None => ("unknown", "unknown"),
408 };
409
410 Self::Failure {
411 r#type: r#type.into(),
412 name: name.into(),
413 error: error_context,
414 }
415 }
416
417 fn new_success(source: Option<&Source>) -> Self {
418 match source {
419 Some(Source::Header { name, .. }) => Self::Success {
420 r#type: "header".into(),
421 name: name.into(),
422 },
423 Some(Source::Cookie { name }) => Self::Success {
424 r#type: "cookie".into(),
425 name: name.into(),
426 },
427 None => Self::Success {
428 r#type: "unknown".into(),
429 name: "unknown".into(),
430 },
431 }
432 }
433
434 #[cfg(test)]
435 fn error(&self) -> Option<&ErrorContext> {
437 match self {
438 Self::Failure { error, .. } => Some(error),
439 _ => None,
440 }
441 }
442}
443
/// Request-context key holding the serialized `JwtStatus`.
///
/// It is written only when a token was extracted from one of the configured sources; requests
/// without any token leave it unset.
const JWT_CONTEXT_KEY: &str = "apollo::authentication::jwt_status";
445
446fn authenticate(
447 config: &JWTConf,
448 jwks_manager: &JwksManager,
449 request: router::Request,
450) -> ControlFlow<router::Response, router::Request> {
451 fn failure_message(
454 request: router::Request,
455 config: &JWTConf,
456 error: AuthenticationError,
457 status: StatusCode,
458 source: Option<&Source>,
459 ) -> ControlFlow<router::Response, router::Request> {
460 let failed = true;
462 increment_jwt_counter_metric(failed);
463
464 tracing::info!(message = %error, "jwt authentication failure");
465
466 let _ = request.context.insert_json_value(
467 JWT_CONTEXT_KEY,
468 serde_json_bytes::json!(JwtStatus::new_failure(source, error.as_context_object())),
469 );
470
471 if config.on_error == OnError::Error {
472 let response = router::Response::infallible_builder()
473 .error(
474 graphql::Error::builder()
475 .message(error.to_string())
476 .extension_code("AUTH_ERROR")
477 .build(),
478 )
479 .status_code(status)
480 .header(header::CONTENT_TYPE, APPLICATION_JSON_HEADER_VALUE.clone())
481 .context(request.context)
482 .build();
483
484 ControlFlow::Break(response)
485 } else {
486 ControlFlow::Continue(request)
487 }
488 }
489
490 fn increment_jwt_counter_metric(failed: bool) {
492 u64_counter!(
493 "apollo.router.operations.authentication.jwt",
494 "Number of requests with JWT authentication",
495 1,
496 authentication.jwt.failed = failed
497 );
498 }
499
500 let mut jwt = None;
501 let mut source_of_extracted_jwt = None;
502 for source in &config.sources {
503 let extracted_jwt = jwks::extract_jwt(
504 source,
505 config.ignore_other_prefixes,
506 request.router_request.headers(),
507 );
508
509 match extracted_jwt {
510 None => continue,
511 Some(Ok(extracted_jwt)) => {
512 source_of_extracted_jwt = Some(source);
513 jwt = Some(extracted_jwt);
514 break;
515 }
516 Some(Err(error)) => {
517 return failure_message(
518 request,
519 config,
520 error,
521 StatusCode::BAD_REQUEST,
522 Some(source),
523 );
524 }
525 }
526 }
527
528 let jwt = match jwt {
529 Some(jwt) => jwt,
530 None => return ControlFlow::Continue(request),
531 };
532
533 let jwt_header = match decode_header(jwt) {
535 Ok(h) => h,
536 Err(e) => {
537 return failure_message(
540 request,
541 config,
542 AuthenticationError::InvalidHeader(HEADER_TOKEN_TRUNCATED.to_owned(), e),
543 StatusCode::BAD_REQUEST,
544 source_of_extracted_jwt,
545 );
546 }
547 };
548
549 let criteria = jwks::JWTCriteria {
551 kid: jwt_header.kid,
552 alg: jwt_header.alg,
553 };
554
555 if let Some(keys) = jwks::search_jwks(jwks_manager, &criteria) {
559 let (issuers, audiences, token_data) = match jwks::decode_jwt(jwt, keys, criteria) {
560 Ok(data) => data,
561 Err((auth_error, status_code)) => {
562 return failure_message(
563 request,
564 config,
565 auth_error,
566 status_code,
567 source_of_extracted_jwt,
568 );
569 }
570 };
571
572 if let Some(configured_issuers) = issuers
573 && let Some(token_issuer) = token_data
574 .claims
575 .as_object()
576 .and_then(|o| o.get("iss"))
577 .and_then(|value| value.as_str())
578 && !configured_issuers.contains(token_issuer)
579 {
580 let mut issuers_for_error: Vec<String> = configured_issuers.into_iter().collect();
581 issuers_for_error.sort(); return failure_message(
583 request,
584 config,
585 AuthenticationError::InvalidIssuer {
586 expected: issuers_for_error
587 .iter()
588 .map(|issuer| issuer.to_string())
589 .collect::<Vec<_>>()
590 .join(", "),
591 token: token_issuer.to_string(),
592 },
593 StatusCode::INTERNAL_SERVER_ERROR,
594 source_of_extracted_jwt,
595 );
596 }
597
598 if let Some(configured_audiences) = audiences {
599 let maybe_token_audiences = token_data.claims.as_object().and_then(|o| o.get("aud"));
600 let Some(maybe_token_audiences) = maybe_token_audiences else {
601 let mut audiences_for_error: Vec<String> =
602 configured_audiences.into_iter().collect();
603 audiences_for_error.sort(); return failure_message(
605 request,
606 config,
607 AuthenticationError::InvalidAudience {
608 expected: audiences_for_error
609 .iter()
610 .map(|audience| audience.to_string())
611 .collect::<Vec<_>>()
612 .join(", "),
613 actual: "<none>".to_string(),
614 },
615 StatusCode::UNAUTHORIZED,
616 source_of_extracted_jwt,
617 );
618 };
619
620 if let Some(token_audience) = maybe_token_audiences.as_str() {
621 if !configured_audiences.contains(token_audience) {
622 let mut audiences_for_error: Vec<String> =
623 configured_audiences.into_iter().collect();
624 audiences_for_error.sort(); return failure_message(
626 request,
627 config,
628 AuthenticationError::InvalidAudience {
629 expected: audiences_for_error
630 .iter()
631 .map(|audience| audience.to_string())
632 .collect::<Vec<_>>()
633 .join(", "),
634 actual: token_audience.to_string(),
635 },
636 StatusCode::UNAUTHORIZED,
637 source_of_extracted_jwt,
638 );
639 }
640 } else {
641 let mut audiences_for_error: Vec<String> =
644 configured_audiences.into_iter().collect();
645 audiences_for_error.sort(); return failure_message(
647 request,
648 config,
649 AuthenticationError::InvalidAudience {
650 expected: audiences_for_error
651 .iter()
652 .map(|audience| audience.to_string())
653 .collect::<Vec<_>>()
654 .join(", "),
655 actual: maybe_token_audiences.to_string(),
656 },
657 StatusCode::UNAUTHORIZED,
658 source_of_extracted_jwt,
659 );
660 }
661 }
662
663 if let Err(e) = request
664 .context
665 .insert(APOLLO_AUTHENTICATION_JWT_CLAIMS, token_data.claims.clone())
666 {
667 return failure_message(
668 request,
669 config,
670 AuthenticationError::CannotInsertClaimsIntoContext(e),
671 StatusCode::INTERNAL_SERVER_ERROR,
672 source_of_extracted_jwt,
673 );
674 }
675 u64_counter!(
680 "apollo.router.operations.jwt",
681 "Number of requests with JWT successful authentication (deprecated, \
682 use `apollo.router.operations.authentication.jwt` \
683 with `authentication.jwt.failed = false` instead)",
684 1
685 );
686 let failed = false;
688 increment_jwt_counter_metric(failed);
689
690 let _ = request.context.insert_json_value(
691 JWT_CONTEXT_KEY,
692 serde_json_bytes::json!(JwtStatus::new_success(source_of_extracted_jwt)),
693 );
694
695 return ControlFlow::Continue(request);
696 }
697
698 let err = criteria.kid.map_or_else(
700 || AuthenticationError::CannotFindSuitableKey(criteria.alg, None),
701 AuthenticationError::CannotFindKID,
702 );
703
704 failure_message(
705 request,
706 config,
707 err,
708 StatusCode::UNAUTHORIZED,
709 source_of_extracted_jwt,
710 )
711}
712
// Registers `AuthenticationPlugin` as the private plugin `"authentication"` in the `"apollo"`
// group (the registration mechanism itself lives in the `register_private_plugin!` macro).
register_private_plugin!("apollo", "authentication", AuthenticationPlugin);