1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4use std::time::SystemTime;
5
6use aws_config::provider_config::ProviderConfig;
7use aws_credential_types::Credentials;
8use aws_credential_types::provider::ProvideCredentials;
9use aws_credential_types::provider::error::CredentialsError;
10use aws_sigv4::http_request::PayloadChecksumKind;
11use aws_sigv4::http_request::SignableBody;
12use aws_sigv4::http_request::SignableRequest;
13use aws_sigv4::http_request::SigningSettings;
14use aws_sigv4::http_request::sign;
15use aws_smithy_async::rt::sleep::TokioSleep;
16use aws_smithy_async::time::SystemTimeSource;
17use aws_smithy_http_client::tls::Provider;
18use aws_smithy_http_client::tls::rustls_provider::CryptoMode;
19use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
20use aws_smithy_runtime_api::client::identity::Identity;
21use aws_types::SdkConfig;
22use aws_types::region::Region;
23use aws_types::sdk_config::SharedCredentialsProvider;
24use http::HeaderMap;
25use http::Request;
26use parking_lot::RwLock;
27use schemars::JsonSchema;
28use serde::Deserialize;
29use serde::Serialize;
30use tokio::sync::mpsc::Sender;
31use tokio::task::JoinHandle;
32use tower::BoxError;
33use tower::ServiceBuilder;
34use tower::ServiceExt;
35
36use crate::services::SubgraphRequest;
37use crate::services::router;
38use crate::services::router::body::RouterBody;
39
40#[derive(Clone, JsonSchema, Deserialize, Serialize, Debug)]
43#[serde(rename_all = "snake_case", deny_unknown_fields)]
44pub(crate) struct AWSSigV4HardcodedConfig {
45 access_key_id: String,
47 secret_access_key: String,
49 region: String,
51 service_name: String,
53 assume_role: Option<AssumeRoleProvider>,
55}
56
57impl ProvideCredentials for AWSSigV4HardcodedConfig {
58 fn provide_credentials<'a>(
59 &'a self,
60 ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
61 where
62 Self: 'a,
63 {
64 aws_credential_types::provider::future::ProvideCredentials::ready(Ok(Credentials::new(
65 self.access_key_id.clone(),
66 self.secret_access_key.clone(),
67 None,
68 None,
69 "apollo-router",
70 )))
71 }
72}
73
/// AWS SigV4 configuration that resolves credentials through the AWS default credentials chain.
#[derive(Clone, JsonSchema, Deserialize, Serialize, Debug)]
#[serde(deny_unknown_fields)]
pub(crate) struct DefaultChainConfig {
    /// The AWS region; used to configure the credentials chain and in the signing scope.
    region: String,
    /// Optional profile name passed to the default credentials chain.
    profile_name: Option<String>,
    /// The AWS service name used in the signing scope.
    service_name: String,
    /// Optional role to assume through STS, using the chain's credentials as the source.
    assume_role: Option<AssumeRoleProvider>,
}
87
/// Role to assume through STS `AssumeRole` before signing requests.
#[derive(Clone, JsonSchema, Deserialize, Serialize, Debug)]
#[serde(deny_unknown_fields)]
pub(crate) struct AssumeRoleProvider {
    /// ARN of the role to assume.
    role_arn: String,
    /// Session name passed to `AssumeRole`.
    session_name: String,
    /// Optional external ID passed to `AssumeRole`.
    external_id: Option<String>,
}
100
/// Where the AWS SigV4 signing credentials come from.
#[derive(Clone, JsonSchema, Deserialize, Serialize, Debug)]
#[serde(rename_all = "snake_case")]
pub(crate) enum AWSSigV4Config {
    /// Static credentials written in the configuration (`hardcoded`).
    Hardcoded(AWSSigV4HardcodedConfig),
    /// Credentials resolved through the AWS default credentials chain (`default_chain`).
    DefaultChain(DefaultChainConfig),
}
108
109impl AWSSigV4Config {
110 async fn get_credentials_provider(&self) -> Arc<dyn ProvideCredentials> {
111 let region = self.region();
112
113 let role_provider_builder = self.assume_role().map(|assume_role_provider| {
114 let rp =
115 aws_config::sts::AssumeRoleProvider::builder(assume_role_provider.role_arn.clone())
116 .configure(
117 &SdkConfig::builder()
118 .http_client(
119 aws_smithy_http_client::Builder::new()
120 .tls_provider(Provider::Rustls(CryptoMode::Ring))
121 .build_https(),
122 )
123 .sleep_impl(TokioSleep::new())
124 .time_source(SystemTimeSource::new())
125 .behavior_version(BehaviorVersion::latest())
126 .build(),
127 )
128 .session_name(assume_role_provider.session_name.clone())
129 .region(region.clone());
130 if let Some(external_id) = &assume_role_provider.external_id {
131 rp.external_id(external_id.as_str())
132 } else {
133 rp
134 }
135 });
136
137 match self {
138 Self::DefaultChain(config) => {
139 let aws_config = credentials_chain_builder().region(region.clone());
140
141 let aws_config = if let Some(profile_name) = &config.profile_name {
142 aws_config.profile_name(profile_name.as_str())
143 } else {
144 aws_config
145 };
146
147 let chain = aws_config.build().await;
148 if let Some(assume_role_provider) = role_provider_builder {
149 Arc::new(assume_role_provider.build_from_provider(chain).await)
150 } else {
151 Arc::new(chain)
152 }
153 }
154 Self::Hardcoded(config) => {
155 let chain = credentials_chain_builder().build().await;
156 if let Some(assume_role_provider) = role_provider_builder {
157 Arc::new(assume_role_provider.build_from_provider(chain).await)
158 } else {
159 Arc::new(config.clone())
160 }
161 }
162 }
163 }
164
165 fn region(&self) -> Region {
166 let region = match self {
167 Self::DefaultChain(config) => config.region.clone(),
168 Self::Hardcoded(config) => config.region.clone(),
169 };
170 aws_types::region::Region::new(region)
171 }
172
173 fn service_name(&self) -> String {
174 match self {
175 Self::DefaultChain(config) => config.service_name.clone(),
176 Self::Hardcoded(config) => config.service_name.clone(),
177 }
178 }
179
180 fn assume_role(&self) -> Option<AssumeRoleProvider> {
181 match self {
182 Self::DefaultChain(config) => config.assume_role.clone(),
183 Self::Hardcoded(config) => config.assume_role.clone(),
184 }
185 }
186}
187
188fn credentials_chain_builder() -> aws_config::default_provider::credentials::Builder {
189 aws_config::default_provider::credentials::DefaultCredentialsChain::builder().configure(
190 ProviderConfig::default()
191 .with_http_client(
192 aws_smithy_http_client::Builder::new()
193 .tls_provider(Provider::Rustls(CryptoMode::Ring))
194 .build_https(),
195 )
196 .with_sleep_impl(TokioSleep::new())
197 .with_time_source(SystemTimeSource::new()),
198 )
199}
200
/// Authentication scheme applied to subgraph requests.
#[derive(Clone, Debug, JsonSchema, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub(crate) enum AuthConfig {
    /// AWS Signature Version 4 request signing (`aws_sig_v4` in the configuration).
    #[serde(rename = "aws_sig_v4")]
    AWSSigV4(AWSSigV4Config),
}
207
/// Subgraph authentication configuration.
#[derive(Clone, Debug, Default, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
#[schemars(rename = "AuthenticationSubgraphConfig")]
pub(crate) struct Config {
    /// Authentication applied to every subgraph.
    #[serde(default)]
    pub(crate) all: Option<AuthConfig>,
    /// Authentication for individual subgraphs, keyed by subgraph name.
    #[serde(default)]
    pub(crate) subgraphs: HashMap<String, AuthConfig>,
}
220
/// Resolved signing parameters: an optional fallback entry plus per-subgraph entries keyed by
/// subgraph name.
// NOTE(review): the reason for this `allow` is not visible here; both fields are read by
// `SubgraphAuth::params_for_service`, so it may be removable — confirm.
#[allow(dead_code)]
#[derive(Clone, Default)]
pub(crate) struct SigningParams {
    /// Parameters used for subgraphs that have no entry in `subgraphs`.
    pub(crate) all: Option<Arc<SigningParamsConfig>>,
    /// Per-subgraph parameters; these take precedence over `all`.
    pub(crate) subgraphs: HashMap<String, Arc<SigningParamsConfig>>,
}
227
/// Everything needed to SigV4-sign requests for one subgraph (or for the fallback entry).
#[derive(Clone)]
pub(crate) struct SigningParamsConfig {
    /// Cached credentials, refreshed in the background.
    credentials_provider: CredentialsProvider,
    /// Region used in the signing scope.
    region: Region,
    /// Service name used in the signing scope; also selects payload checksum handling.
    service_name: String,
    /// Subgraph name recorded on the failure counter when credentials cannot be obtained.
    subgraph_name: String,
}
235
/// Cache of AWS credentials kept fresh by a background tokio task; clones share the cache.
#[derive(Clone, Debug)]
struct CredentialsProvider {
    /// Latest credentials, handed out by `provide_credentials` and replaced by the updater task.
    credentials: Arc<RwLock<Credentials>>,
    /// Handle of the updater task spawned by `from_provide_credentials`. It is only kept, never
    /// awaited or aborted; the task stops by itself once the refresh channel closes.
    _credentials_updater_handle: Arc<JoinHandle<()>>,
    /// Channel used to ask the updater task for an immediate refresh.
    // In this file its only user (`CredentialsProvider::refresh_credentials`) is called from
    // tests alone, hence the `allow`.
    #[allow(dead_code)]
    refresh_credentials: Sender<()>,
}
243
/// Credentials are refreshed once less than this much time remains before they expire.
const MIN_REMAINING_DURATION: Duration = std::time::Duration::from_secs(60 * 5);

/// Delay before retrying after a failed credentials refresh.
const RETRY_DURATION: Duration = std::time::Duration::from_secs(60);
248
249impl CredentialsProvider {
250 async fn from_provide_credentials(
251 provide_credentials: impl ProvideCredentials + 'static,
252 ) -> Result<Self, CredentialsError> {
253 let credentials_provider = SharedCredentialsProvider::new(provide_credentials);
254 let (sender, mut refresh_credentials_receiver) = tokio::sync::mpsc::channel(1);
255 let credentials = credentials_provider.provide_credentials().await?;
256 let mut refresh_timer = next_refresh_timer(&credentials);
257 let credentials = Arc::new(RwLock::new(credentials));
258 let c2 = credentials.clone();
259 let crp2 = credentials_provider.clone();
260 let handle = tokio::spawn(async move {
261 loop {
262 tokio::select! {
263 _ = tokio::time::sleep(refresh_timer.unwrap_or(Duration::MAX)) => {
264 refresh_timer = refresh_credentials(&crp2, &c2).await;
265 },
266 rcr = refresh_credentials_receiver.recv() => {
267 if rcr.is_some() {
268 refresh_timer = refresh_credentials(&crp2, &c2).await;
269 } else {
270 return;
271 }
272 },
273 }
274 }
275 });
276 Ok(Self {
277 _credentials_updater_handle: Arc::new(handle),
278 refresh_credentials: sender,
279 credentials,
280 })
281 }
282
283 #[allow(dead_code)]
284 pub(crate) async fn refresh_credentials(&self) {
285 let _ = self.refresh_credentials.send(()).await;
286 }
287}
288
289async fn refresh_credentials(
290 credentials_provider: &(impl ProvideCredentials + 'static),
291 credentials: &RwLock<Credentials>,
292) -> Option<Duration> {
293 match credentials_provider.provide_credentials().await {
294 Ok(new_credentials) => {
295 let mut credentials = credentials.write();
296 *credentials = new_credentials;
297 next_refresh_timer(&credentials)
298 }
299 Err(e) => {
300 tracing::warn!("authentication: couldn't refresh credentials {e}");
301 Some(RETRY_DURATION)
302 }
303 }
304}
305
306fn next_refresh_timer(credentials: &Credentials) -> Option<Duration> {
307 credentials
308 .expiry()
309 .and_then(|e| e.duration_since(SystemTime::now()).ok())
310 .and_then(|d| {
311 d.checked_sub(MIN_REMAINING_DURATION)
312 .or(Some(Duration::from_secs(0)))
313 })
314}
315
316impl ProvideCredentials for CredentialsProvider {
317 fn provide_credentials<'a>(
318 &'a self,
319 ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
320 where
321 Self: 'a,
322 {
323 aws_credential_types::provider::future::ProvideCredentials::ready(Ok(self
324 .credentials
325 .read()
326 .clone()))
327 }
328}
329
330impl SigningParamsConfig {
331 pub(crate) async fn sign(
332 &self,
333 mut req: Request<RouterBody>,
334 subgraph_name: &str,
335 ) -> Result<Request<RouterBody>, BoxError> {
336 let credentials = self.credentials().await?;
337 let builder = self.signing_params_builder(&credentials).await?;
338 let (parts, body) = req.into_parts();
339 let headers = HeaderMap::<&'static str>::default();
342 let body_bytes = router::body::into_bytes(body).await?.to_vec();
344 let signable_request = SignableRequest::new(
345 parts.method.as_str(),
346 parts.uri.to_string(),
347 headers.iter().map(|(name, value)| (name.as_str(), *value)),
348 match self.service_name.as_str() {
349 "vpc-lattice-svcs" => SignableBody::UnsignedPayload,
350 _ => SignableBody::Bytes(body_bytes.as_slice()),
351 },
352 )?;
353
354 let signing_params = builder.build().expect("all required fields set");
355
356 let (signing_instructions, _signature) = sign(signable_request, &signing_params.into())
357 .map_err(|err| {
358 increment_failure_counter(subgraph_name);
359 let error = format!("failed to sign GraphQL body for AWS SigV4: {err}");
360 tracing::error!("{}", error);
361 error
362 })?
363 .into_parts();
364 req = Request::<RouterBody>::from_parts(parts, router::body::from_bytes(body_bytes));
365 signing_instructions.apply_to_request_http1x(&mut req);
366 increment_success_counter(subgraph_name);
367 Ok(req)
368 }
369
370 pub(crate) async fn sign_empty(
372 &self,
373 mut req: Request<()>,
374 subgraph_name: &str,
375 ) -> Result<Request<()>, BoxError> {
376 let credentials = self.credentials().await?;
377 let builder = self.signing_params_builder(&credentials).await?;
378 let (parts, _) = req.into_parts();
379 let headers = HeaderMap::<&'static str>::default();
382 let signable_request = SignableRequest::new(
384 parts.method.as_str(),
385 parts.uri.to_string(),
386 headers.iter().map(|(name, value)| (name.as_str(), *value)),
387 match self.service_name.as_str() {
388 "vpc-lattice-svcs" => SignableBody::UnsignedPayload,
389 _ => SignableBody::Bytes(&[]),
390 },
391 )?;
392
393 let signing_params = builder.build().expect("all required fields set");
394
395 let (signing_instructions, _signature) = sign(signable_request, &signing_params.into())
396 .map_err(|err| {
397 increment_failure_counter(subgraph_name);
398 let error = format!("failed to sign GraphQL body for AWS SigV4: {err}");
399 tracing::error!("{}", error);
400 error
401 })?
402 .into_parts();
403 req = Request::<()>::from_parts(parts, ());
404 signing_instructions.apply_to_request_http1x(&mut req);
405 increment_success_counter(subgraph_name);
406 Ok(req)
407 }
408
409 async fn signing_params_builder<'s>(
410 &'s self,
411 identity: &'s Identity,
412 ) -> Result<aws_sigv4::sign::v4::signing_params::Builder<'s, SigningSettings>, BoxError> {
413 let settings = get_signing_settings(self);
414 let builder = aws_sigv4::sign::v4::SigningParams::builder()
415 .identity(identity)
416 .region(self.region.as_ref())
417 .name(&self.service_name)
418 .time(SystemTime::now())
419 .settings(settings);
420 Ok(builder)
421 }
422
423 async fn credentials(&self) -> Result<Identity, BoxError> {
424 self.credentials_provider
425 .provide_credentials()
426 .await
427 .map_err(|err| {
428 increment_failure_counter(self.subgraph_name.as_str());
429 let error = format!("failed to get credentials for AWS SigV4 signing: {err}");
430 tracing::error!("{}", error);
431 error.into()
432 })
433 .map(Into::into)
434 }
435}
436
/// Increments `apollo.router.operations.authentication.aws.sigv4` with
/// `authentication.aws.sigv4.failed = false` for `subgraph_name`.
fn increment_success_counter(subgraph_name: &str) {
    u64_counter!(
        "apollo.router.operations.authentication.aws.sigv4",
        "Number of subgraph requests signed with AWS SigV4",
        1,
        authentication.aws.sigv4.failed = false,
        subgraph.service.name = subgraph_name.to_string()
    );
}

/// Increments `apollo.router.operations.authentication.aws.sigv4` with
/// `authentication.aws.sigv4.failed = true` for `subgraph_name`.
fn increment_failure_counter(subgraph_name: &str) {
    u64_counter!(
        "apollo.router.operations.authentication.aws.sigv4",
        "Number of subgraph requests signed with AWS SigV4",
        1,
        authentication.aws.sigv4.failed = true,
        subgraph.service.name = subgraph_name.to_string()
    );
}
455
456pub(super) async fn make_signing_params(
457 config: &AuthConfig,
458 subgraph_name: &str,
459) -> Result<SigningParamsConfig, BoxError> {
460 match config {
461 AuthConfig::AWSSigV4(config) => {
462 let credentials_provider = config.get_credentials_provider().await;
463 Ok(SigningParamsConfig {
464 region: config.region(),
465 service_name: config.service_name(),
466 credentials_provider: CredentialsProvider::from_provide_credentials(
467 credentials_provider,
468 )
469 .await
470 .map_err(BoxError::from)?,
471 subgraph_name: subgraph_name.to_string(),
472 })
473 }
474 }
475}
476
477fn get_signing_settings(signing_params: &SigningParamsConfig) -> SigningSettings {
480 let mut settings = SigningSettings::default();
481 settings.payload_checksum_kind = match signing_params.service_name.as_str() {
482 "appsync" | "s3" | "vpc-lattice-svcs" => PayloadChecksumKind::XAmzSha256,
483 _ => PayloadChecksumKind::NoHeader,
484 };
485 settings
486}
487
/// Wraps subgraph services so that their requests carry the SigV4 signing parameters that
/// apply to the subgraph.
pub(super) struct SubgraphAuth {
    /// Signing parameters, shared by every wrapped service.
    pub(super) signing_params: Arc<SigningParams>,
}
491
492impl SubgraphAuth {
493 pub(super) fn subgraph_service(
494 &self,
495 name: &str,
496 service: crate::services::subgraph::BoxService,
497 ) -> crate::services::subgraph::BoxService {
498 if let Some(signing_params) = self.params_for_service(name) {
499 ServiceBuilder::new()
500 .map_request(move |req: SubgraphRequest| {
501 let signing_params = signing_params.clone();
502 req.context
503 .extensions()
504 .with_lock(|lock| lock.insert(signing_params));
505 req
506 })
507 .service(service)
508 .boxed()
509 } else {
510 service
511 }
512 }
513}
514
515impl SubgraphAuth {
516 fn params_for_service(&self, service_name: &str) -> Option<Arc<SigningParamsConfig>> {
517 self.signing_params
518 .subgraphs
519 .get(service_name)
520 .cloned()
521 .or_else(|| self.signing_params.all.clone())
522 }
523}
524
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use http::header::CONTENT_LENGTH;
    use http::header::CONTENT_TYPE;
    use http::header::HOST;
    use regex::Regex;
    use tower::Service;

    use super::*;
    use crate::Context;
    use crate::graphql::Request;
    use crate::plugin::test::MockSubgraphService;
    use crate::query_planner::fetch::OperationKind;
    use crate::services::SubgraphRequest;
    use crate::services::SubgraphResponse;
    use crate::services::subgraph::SubgraphRequestId;

    /// Builds hardcoded signing params for `service_name` and returns the resulting
    /// `aws_sigv4` signing settings.
    async fn test_signing_settings(service_name: &str) -> SigningSettings {
        let params: SigningParamsConfig = make_signing_params(
            &AuthConfig::AWSSigV4(AWSSigV4Config::Hardcoded(AWSSigV4HardcodedConfig {
                access_key_id: "id".to_string(),
                secret_access_key: "secret".to_string(),
                region: "us-east-1".to_string(),
                service_name: service_name.to_string(),
                assume_role: None,
            })),
            "all",
        )
        .await
        .unwrap();
        get_signing_settings(&params)
    }

    #[tokio::test]
    async fn test_get_signing_settings() {
        assert_eq!(
            PayloadChecksumKind::XAmzSha256,
            test_signing_settings("s3").await.payload_checksum_kind
        );
        assert_eq!(
            PayloadChecksumKind::XAmzSha256,
            test_signing_settings("vpc-lattice-svcs")
                .await
                .payload_checksum_kind
        );
        assert_eq!(
            PayloadChecksumKind::XAmzSha256,
            test_signing_settings("appsync").await.payload_checksum_kind
        );
        assert_eq!(
            PayloadChecksumKind::NoHeader,
            test_signing_settings("something-else")
                .await
                .payload_checksum_kind
        );
    }

    #[test]
    fn test_all_aws_sig_v4_hardcoded_config() {
        serde_yaml::from_str::<Config>(
            r#"
        all:
          aws_sig_v4:
            hardcoded:
              access_key_id: "test"
              secret_access_key: "test"
              region: "us-east-1"
              service_name: "lambda"
        "#,
        )
        .unwrap();
    }

    #[test]
    fn test_subgraph_aws_sig_v4_hardcoded_config() {
        serde_yaml::from_str::<Config>(
            r#"
        subgraphs:
          products:
            aws_sig_v4:
              hardcoded:
                access_key_id: "test"
                secret_access_key: "test"
                region: "us-east-1"
                service_name: "test_service"
        "#,
        )
        .unwrap();
    }

    #[test]
    fn test_aws_sig_v4_default_chain_assume_role_config() {
        serde_yaml::from_str::<Config>(
            r#"
        all:
          aws_sig_v4:
            default_chain:
              profile_name: "my-test-profile"
              region: "us-east-1"
              service_name: "lambda"
              assume_role:
                role_arn: "test-arn"
                session_name: "test-session"
                external_id: "test-id"
        "#,
        )
        .unwrap();
    }

    #[tokio::test]
    async fn test_lattice_body_payload_should_be_unsigned() -> Result<(), BoxError> {
        let subgraph_request = example_request();

        let mut mock = MockSubgraphService::new();
        mock.expect_call()
            .times(1)
            .withf(|request| {
                let http_request = get_signed_request(request, "products".to_string());
                assert_eq!(
                    "UNSIGNED-PAYLOAD",
                    http_request
                        .headers()
                        .get("x-amz-content-sha256")
                        .unwrap()
                        .to_str()
                        .unwrap()
                );
                true
            })
            .returning(example_response);

        let mut service = SubgraphAuth {
            signing_params: Arc::new(SigningParams {
                all: make_signing_params(
                    &AuthConfig::AWSSigV4(AWSSigV4Config::Hardcoded(AWSSigV4HardcodedConfig {
                        access_key_id: "id".to_string(),
                        secret_access_key: "secret".to_string(),
                        region: "us-east-1".to_string(),
                        service_name: "vpc-lattice-svcs".to_string(),
                        assume_role: None,
                    })),
                    "all",
                )
                .await
                .ok()
                .map(Arc::new),
                subgraphs: Default::default(),
            }),
        }
        .subgraph_service("test_subgraph", mock.boxed());

        service.ready().await?.call(subgraph_request).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_aws_sig_v4_headers() -> Result<(), BoxError> {
        let subgraph_request = example_request();

        let mut mock = MockSubgraphService::new();
        mock.expect_call()
            .times(1)
            .withf(|request| {
                let http_request = get_signed_request(request, "products".to_string());
                let authorization_regex = Regex::new(r"AWS4-HMAC-SHA256 Credential=id/\d{8}/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=[a-f0-9]{64}").unwrap();
                let authorization_header_str = http_request.headers().get("authorization").unwrap().to_str().unwrap();
                assert_eq!(match authorization_regex.find(authorization_header_str) {
                    Some(m) => m.as_str(),
                    None => "no match"
                }, authorization_header_str);

                let x_amz_date_regex = Regex::new(r"\d{8}T\d{6}Z").unwrap();
                let x_amz_date_header_str = http_request.headers().get("x-amz-date").unwrap().to_str().unwrap();
                assert_eq!(match x_amz_date_regex.find(x_amz_date_header_str) {
                    Some(m) => m.as_str(),
                    None => "no match"
                }, x_amz_date_header_str);

                assert_eq!(http_request.headers().get("x-amz-content-sha256").unwrap(), "255959b4c6e11c1080f61ce0d75eb1b565c1772173335a7828ba9c13c25c0d8c");

                true
            })
            .returning(example_response);

        let mut service = SubgraphAuth {
            signing_params: Arc::new(SigningParams {
                all: make_signing_params(
                    &AuthConfig::AWSSigV4(AWSSigV4Config::Hardcoded(AWSSigV4HardcodedConfig {
                        access_key_id: "id".to_string(),
                        secret_access_key: "secret".to_string(),
                        region: "us-east-1".to_string(),
                        service_name: "s3".to_string(),
                        assume_role: None,
                    })),
                    "all",
                )
                .await
                .ok()
                .map(Arc::new),
                subgraphs: Default::default(),
            }),
        }
        .subgraph_service("test_subgraph", mock.boxed());

        service.ready().await?.call(subgraph_request).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_credentials_provider_keeps_credentials_in_cache() -> Result<(), BoxError> {
        #[derive(Debug, Default, Clone)]
        struct TestCredentialsProvider {
            times_called: Arc<AtomicUsize>,
        }

        impl ProvideCredentials for TestCredentialsProvider {
            fn provide_credentials<'a>(
                &'a self,
            ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
            where
                Self: 'a,
            {
                self.times_called.fetch_add(1, Ordering::SeqCst);
                aws_credential_types::provider::future::ProvideCredentials::ready(Ok(
                    Credentials::new("test_key", "test_secret", None, None, "test_provider"),
                ))
            }
        }

        let tcp = TestCredentialsProvider::default();

        let cp = CredentialsProvider::from_provide_credentials(tcp.clone())
            .await
            .unwrap();

        let _ = cp.provide_credentials().await.unwrap();
        let _ = cp.provide_credentials().await.unwrap();

        assert_eq!(1, tcp.times_called.load(Ordering::SeqCst));

        cp.refresh_credentials().await;
        tokio::time::sleep(Duration::from_millis(50)).await;

        let _ = cp.provide_credentials().await.unwrap();
        let _ = cp.provide_credentials().await.unwrap();

        assert_eq!(2, tcp.times_called.load(Ordering::SeqCst));

        Ok(())
    }

    #[tokio::test]
    async fn test_credentials_provider_refresh_on_stale() -> Result<(), BoxError> {
        #[derive(Debug, Default, Clone)]
        struct TestCredentialsProvider {
            times_called: Arc<AtomicUsize>,
        }

        impl ProvideCredentials for TestCredentialsProvider {
            fn provide_credentials<'a>(
                &'a self,
            ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
            where
                Self: 'a,
            {
                self.times_called.fetch_add(1, Ordering::SeqCst);
                aws_credential_types::provider::future::ProvideCredentials::ready(Ok(
                    Credentials::new(
                        "test_key",
                        "test_secret",
                        None,
                        SystemTime::now().checked_add(Duration::from_secs(60 * 5 + 1)),
                        "test_provider",
                    ),
                ))
            }
        }

        let tcp = TestCredentialsProvider::default();

        let cp = CredentialsProvider::from_provide_credentials(tcp.clone())
            .await
            .unwrap();

        let _ = cp.provide_credentials().await.unwrap();
        let _ = cp.provide_credentials().await.unwrap();

        assert_eq!(1, tcp.times_called.load(Ordering::SeqCst));

        cp.refresh_credentials().await;
        tokio::time::sleep(Duration::from_millis(50)).await;

        let _ = cp.provide_credentials().await.unwrap();
        let _ = cp.provide_credentials().await.unwrap();

        assert_eq!(2, tcp.times_called.load(Ordering::SeqCst));

        tokio::time::sleep(Duration::from_secs(1)).await;

        assert_eq!(3, tcp.times_called.load(Ordering::SeqCst));

        Ok(())
    }

    /// Mock subgraph response used by the signing tests.
    fn example_response(req: SubgraphRequest) -> Result<SubgraphResponse, BoxError> {
        Ok(SubgraphResponse::new_from_response(
            http::Response::default(),
            Context::new(),
            req.subgraph_name,
            SubgraphRequestId(String::new()),
        ))
    }

    /// A minimal subgraph request targeting `https://test-endpoint.com`.
    fn example_request() -> SubgraphRequest {
        SubgraphRequest::builder()
            .supergraph_request(Arc::new(
                http::Request::builder()
                    .header(HOST, "host")
                    .header(CONTENT_LENGTH, "2")
                    .header(CONTENT_TYPE, "graphql")
                    .body(
                        Request::builder()
                            .query("query")
                            .operation_name("my_operation_name")
                            .build(),
                    )
                    .expect("expecting valid request"),
            ))
            .subgraph_request(
                http::Request::builder()
                    .header(HOST, "rhost")
                    .header(CONTENT_LENGTH, "22")
                    .header(CONTENT_TYPE, "graphql")
                    .uri("https://test-endpoint.com")
                    .body(Request::builder().query("query").build())
                    .expect("expecting valid request"),
            )
            .operation_kind(OperationKind::Query)
            .context(Context::new())
            .subgraph_name(String::default())
            .build()
    }

    /// Signs `request`'s subgraph HTTP request with the `Arc<SigningParamsConfig>` stored in its
    /// context extensions. Runs on a dedicated thread and runtime because it is called from a
    /// synchronous mock matcher.
    fn get_signed_request(
        request: &SubgraphRequest,
        service_name: String,
    ) -> http::Request<RouterBody> {
        let signing_params = request
            .context
            .extensions()
            .with_lock(|lock| lock.get::<Arc<SigningParamsConfig>>().cloned())
            .unwrap();

        let http_request = request
            .clone()
            .subgraph_request
            .map(|body| router::body::from_bytes(serde_json::to_string(&body).unwrap()));

        std::thread::spawn(move || {
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                signing_params
                    .sign(http_request, service_name.as_str())
                    .await
                    .unwrap()
            })
        })
        .join()
        .unwrap()
    }
}