Skip to main content

apollo_router/plugins/authentication/
subgraph.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4use std::time::SystemTime;
5
6use aws_config::provider_config::ProviderConfig;
7use aws_credential_types::Credentials;
8use aws_credential_types::provider::ProvideCredentials;
9use aws_credential_types::provider::error::CredentialsError;
10use aws_sigv4::http_request::PayloadChecksumKind;
11use aws_sigv4::http_request::SignableBody;
12use aws_sigv4::http_request::SignableRequest;
13use aws_sigv4::http_request::SigningSettings;
14use aws_sigv4::http_request::sign;
15use aws_smithy_async::rt::sleep::TokioSleep;
16use aws_smithy_async::time::SystemTimeSource;
17use aws_smithy_http_client::tls::Provider;
18use aws_smithy_http_client::tls::rustls_provider::CryptoMode;
19use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
20use aws_smithy_runtime_api::client::identity::Identity;
21use aws_types::SdkConfig;
22use aws_types::region::Region;
23use aws_types::sdk_config::SharedCredentialsProvider;
24use http::HeaderMap;
25use http::Request;
26use parking_lot::RwLock;
27use schemars::JsonSchema;
28use serde::Deserialize;
29use serde::Serialize;
30use tokio::sync::mpsc::Sender;
31use tokio::task::JoinHandle;
32use tower::BoxError;
33use tower::ServiceBuilder;
34use tower::ServiceExt;
35
36use crate::services::SubgraphRequest;
37use crate::services::router;
38use crate::services::router::body::RouterBody;
39
40/// Hardcoded Config using access_key and secret.
41/// Prefer using DefaultChain instead.
42#[derive(Clone, JsonSchema, Deserialize, Serialize, Debug)]
43#[serde(rename_all = "snake_case", deny_unknown_fields)]
44pub(crate) struct AWSSigV4HardcodedConfig {
45    /// The ID for this access key.
46    access_key_id: String,
47    /// The secret key used to sign requests.
48    secret_access_key: String,
49    /// The AWS region this chain applies to.
50    region: String,
51    /// The service you're trying to access, eg: "s3", "vpc-lattice-svcs", etc.
52    service_name: String,
53    /// Specify assumed role configuration.
54    assume_role: Option<AssumeRoleProvider>,
55}
56
57impl ProvideCredentials for AWSSigV4HardcodedConfig {
58    fn provide_credentials<'a>(
59        &'a self,
60    ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
61    where
62        Self: 'a,
63    {
64        aws_credential_types::provider::future::ProvideCredentials::ready(Ok(Credentials::new(
65            self.access_key_id.clone(),
66            self.secret_access_key.clone(),
67            None,
68            None,
69            "apollo-router",
70        )))
71    }
72}
73
/// Configuration of the DefaultChainProvider
// Credentials are resolved through the AWS default credentials chain: `region` is
// applied to the chain and, when present, `profile_name` selects the profile to use.
// `service_name` also selects the payload signing behaviour (see `get_signing_settings`).
#[derive(Clone, JsonSchema, Deserialize, Serialize, Debug)]
#[serde(deny_unknown_fields)]
pub(crate) struct DefaultChainConfig {
    /// The AWS region this chain applies to.
    region: String,
    /// The profile name used by this provider
    profile_name: Option<String>,
    /// The service you're trying to access, eg: "s3", "vpc-lattice-svcs", etc.
    service_name: String,
    /// Specify assumed role configuration.
    assume_role: Option<AssumeRoleProvider>,
}
87
/// Specify assumed role configuration.
// Turned into an STS `AssumeRoleProvider` in `AWSSigV4Config::get_credentials_provider`,
// using the configured region; `external_id` is only forwarded to STS when it is set.
#[derive(Clone, JsonSchema, Deserialize, Serialize, Debug)]
#[serde(deny_unknown_fields)]
pub(crate) struct AssumeRoleProvider {
    /// Amazon Resource Name (ARN)
    /// for the role assumed when making requests
    role_arn: String,
    /// Uniquely identify a session when the same role is assumed by different principals or for different reasons.
    session_name: String,
    /// Unique identifier that might be required when you assume a role in another account.
    external_id: Option<String>,
}
100
/// Configure AWS sigv4 auth.
// Variants use snake_case names, so in YAML the credential source is selected with a
// `hardcoded:` or `default_chain:` key.
#[derive(Clone, JsonSchema, Deserialize, Serialize, Debug)]
#[serde(rename_all = "snake_case")]
pub(crate) enum AWSSigV4Config {
    // Static access key / secret taken from the configuration.
    Hardcoded(AWSSigV4HardcodedConfig),
    // Credentials resolved through the AWS default credentials chain.
    DefaultChain(DefaultChainConfig),
}
108
109impl AWSSigV4Config {
110    async fn get_credentials_provider(&self) -> Arc<dyn ProvideCredentials> {
111        let region = self.region();
112
113        let role_provider_builder = self.assume_role().map(|assume_role_provider| {
114            let rp =
115                aws_config::sts::AssumeRoleProvider::builder(assume_role_provider.role_arn.clone())
116                    .configure(
117                        &SdkConfig::builder()
118                            .http_client(
119                                aws_smithy_http_client::Builder::new()
120                                    .tls_provider(Provider::Rustls(CryptoMode::Ring))
121                                    .build_https(),
122                            )
123                            .sleep_impl(TokioSleep::new())
124                            .time_source(SystemTimeSource::new())
125                            .behavior_version(BehaviorVersion::latest())
126                            .build(),
127                    )
128                    .session_name(assume_role_provider.session_name.clone())
129                    .region(region.clone());
130            if let Some(external_id) = &assume_role_provider.external_id {
131                rp.external_id(external_id.as_str())
132            } else {
133                rp
134            }
135        });
136
137        match self {
138            Self::DefaultChain(config) => {
139                let aws_config = credentials_chain_builder().region(region.clone());
140
141                let aws_config = if let Some(profile_name) = &config.profile_name {
142                    aws_config.profile_name(profile_name.as_str())
143                } else {
144                    aws_config
145                };
146
147                let chain = aws_config.build().await;
148                if let Some(assume_role_provider) = role_provider_builder {
149                    Arc::new(assume_role_provider.build_from_provider(chain).await)
150                } else {
151                    Arc::new(chain)
152                }
153            }
154            Self::Hardcoded(config) => {
155                let chain = credentials_chain_builder().build().await;
156                if let Some(assume_role_provider) = role_provider_builder {
157                    Arc::new(assume_role_provider.build_from_provider(chain).await)
158                } else {
159                    Arc::new(config.clone())
160                }
161            }
162        }
163    }
164
165    fn region(&self) -> Region {
166        let region = match self {
167            Self::DefaultChain(config) => config.region.clone(),
168            Self::Hardcoded(config) => config.region.clone(),
169        };
170        aws_types::region::Region::new(region)
171    }
172
173    fn service_name(&self) -> String {
174        match self {
175            Self::DefaultChain(config) => config.service_name.clone(),
176            Self::Hardcoded(config) => config.service_name.clone(),
177        }
178    }
179
180    fn assume_role(&self) -> Option<AssumeRoleProvider> {
181        match self {
182            Self::DefaultChain(config) => config.assume_role.clone(),
183            Self::Hardcoded(config) => config.assume_role.clone(),
184        }
185    }
186}
187
188fn credentials_chain_builder() -> aws_config::default_provider::credentials::Builder {
189    aws_config::default_provider::credentials::DefaultCredentialsChain::builder().configure(
190        ProviderConfig::default()
191            .with_http_client(
192                aws_smithy_http_client::Builder::new()
193                    .tls_provider(Provider::Rustls(CryptoMode::Ring))
194                    .build_https(),
195            )
196            .with_sleep_impl(TokioSleep::new())
197            .with_time_source(SystemTimeSource::new()),
198    )
199}
200
// Authentication mechanism applied to subgraph requests. AWS SigV4 is the only
// variant; in YAML it is configured under the `aws_sig_v4` key.
#[derive(Clone, Debug, JsonSchema, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub(crate) enum AuthConfig {
    #[serde(rename = "aws_sig_v4")]
    AWSSigV4(AWSSigV4Config),
}
207
/// Configure subgraph authentication
// Both fields may be omitted from the YAML (`#[serde(default)]`).
#[derive(Clone, Debug, Default, JsonSchema, Deserialize)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
#[schemars(rename = "AuthenticationSubgraphConfig")]
pub(crate) struct Config {
    /// Configuration that will apply to all subgraphs.
    #[serde(default)]
    pub(crate) all: Option<AuthConfig>,
    #[serde(default)]
    /// Create a configuration that will apply only to a specific subgraph.
    // Keyed by subgraph name (e.g. `subgraphs: products: aws_sig_v4: ...`).
    pub(crate) subgraphs: HashMap<String, AuthConfig>,
}
220
/// Resolved signing parameters for all subgraphs: an optional `all` entry plus
/// per-subgraph entries, which take precedence over `all` in
/// `SubgraphAuth::params_for_service`.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file; confirm
// whether it is still needed.
#[allow(dead_code)]
#[derive(Clone, Default)]
pub(crate) struct SigningParams {
    pub(crate) all: Option<Arc<SigningParamsConfig>>,
    pub(crate) subgraphs: HashMap<String, Arc<SigningParamsConfig>>,
}
227
/// Ready-to-use AWS SigV4 signing parameters built by `make_signing_params`.
#[derive(Clone)]
pub(crate) struct SigningParamsConfig {
    // Cached credentials, refreshed in the background; consulted on every signature.
    credentials_provider: CredentialsProvider,
    // Region the requests are signed for.
    region: Region,
    // AWS service name; also selects payload signing behaviour (see `get_signing_settings`).
    service_name: String,
    // Name passed to `make_signing_params` (e.g. "all"); used as the metric label when
    // fetching credentials fails.
    subgraph_name: String,
}
235
/// Caching credentials provider. It serves credentials from `credentials`, which a
/// background task started in `from_provide_credentials` keeps up to date.
#[derive(Clone, Debug)]
struct CredentialsProvider {
    // Latest credentials; written by the background task, read on each `provide_credentials`.
    credentials: Arc<RwLock<Credentials>>,
    // Handle of the background refresh task. It is only stored, never awaited or aborted
    // here; the task returns on its own once every `refresh_credentials` sender is dropped.
    _credentials_updater_handle: Arc<JoinHandle<()>>,
    // Sending `()` asks the background task to refresh the credentials right away.
    #[allow(dead_code)]
    refresh_credentials: Sender<()>,
}
243
/// Credentials are refreshed once they are within this margin of their expiry (5 minutes).
const MIN_REMAINING_DURATION: Duration = Duration::from_secs(5 * 60);
/// Delay before trying again after a failed credentials refresh (1 minute).
const RETRY_DURATION: Duration = Duration::from_secs(60);
248
249impl CredentialsProvider {
250    async fn from_provide_credentials(
251        provide_credentials: impl ProvideCredentials + 'static,
252    ) -> Result<Self, CredentialsError> {
253        let credentials_provider = SharedCredentialsProvider::new(provide_credentials);
254        let (sender, mut refresh_credentials_receiver) = tokio::sync::mpsc::channel(1);
255        let credentials = credentials_provider.provide_credentials().await?;
256        let mut refresh_timer = next_refresh_timer(&credentials);
257        let credentials = Arc::new(RwLock::new(credentials));
258        let c2 = credentials.clone();
259        let crp2 = credentials_provider.clone();
260        let handle = tokio::spawn(async move {
261            loop {
262                tokio::select! {
263                    _ = tokio::time::sleep(refresh_timer.unwrap_or(Duration::MAX)) => {
264                       refresh_timer = refresh_credentials(&crp2, &c2).await;
265                    },
266                    rcr = refresh_credentials_receiver.recv() => {
267                        if rcr.is_some() {
268                            refresh_timer = refresh_credentials(&crp2, &c2).await;
269                        } else {
270                            return;
271                        }
272                    },
273                }
274            }
275        });
276        Ok(Self {
277            _credentials_updater_handle: Arc::new(handle),
278            refresh_credentials: sender,
279            credentials,
280        })
281    }
282
283    #[allow(dead_code)]
284    pub(crate) async fn refresh_credentials(&self) {
285        let _ = self.refresh_credentials.send(()).await;
286    }
287}
288
289async fn refresh_credentials(
290    credentials_provider: &(impl ProvideCredentials + 'static),
291    credentials: &RwLock<Credentials>,
292) -> Option<Duration> {
293    match credentials_provider.provide_credentials().await {
294        Ok(new_credentials) => {
295            let mut credentials = credentials.write();
296            *credentials = new_credentials;
297            next_refresh_timer(&credentials)
298        }
299        Err(e) => {
300            tracing::warn!("authentication: couldn't refresh credentials {e}");
301            Some(RETRY_DURATION)
302        }
303    }
304}
305
306fn next_refresh_timer(credentials: &Credentials) -> Option<Duration> {
307    credentials
308        .expiry()
309        .and_then(|e| e.duration_since(SystemTime::now()).ok())
310        .and_then(|d| {
311            d.checked_sub(MIN_REMAINING_DURATION)
312                .or(Some(Duration::from_secs(0)))
313        })
314}
315
316impl ProvideCredentials for CredentialsProvider {
317    fn provide_credentials<'a>(
318        &'a self,
319    ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
320    where
321        Self: 'a,
322    {
323        aws_credential_types::provider::future::ProvideCredentials::ready(Ok(self
324            .credentials
325            .read()
326            .clone()))
327    }
328}
329
330impl SigningParamsConfig {
331    pub(crate) async fn sign(
332        &self,
333        mut req: Request<RouterBody>,
334        subgraph_name: &str,
335    ) -> Result<Request<RouterBody>, BoxError> {
336        let credentials = self.credentials().await?;
337        let builder = self.signing_params_builder(&credentials).await?;
338        let (parts, body) = req.into_parts();
339        // Depending on the service, AWS refuses sigv4 payloads that contain specific headers.
340        // We'll go with default signed headers
341        let headers = HeaderMap::<&'static str>::default();
342        // UnsignedPayload only applies to lattice
343        let body_bytes = router::body::into_bytes(body).await?.to_vec();
344        let signable_request = SignableRequest::new(
345            parts.method.as_str(),
346            parts.uri.to_string(),
347            headers.iter().map(|(name, value)| (name.as_str(), *value)),
348            match self.service_name.as_str() {
349                "vpc-lattice-svcs" => SignableBody::UnsignedPayload,
350                _ => SignableBody::Bytes(body_bytes.as_slice()),
351            },
352        )?;
353
354        let signing_params = builder.build().expect("all required fields set");
355
356        let (signing_instructions, _signature) = sign(signable_request, &signing_params.into())
357            .map_err(|err| {
358                increment_failure_counter(subgraph_name);
359                let error = format!("failed to sign GraphQL body for AWS SigV4: {err}");
360                tracing::error!("{}", error);
361                error
362            })?
363            .into_parts();
364        req = Request::<RouterBody>::from_parts(parts, router::body::from_bytes(body_bytes));
365        signing_instructions.apply_to_request_http1x(&mut req);
366        increment_success_counter(subgraph_name);
367        Ok(req)
368    }
369
370    // This function is the same as above, except it's a new one because () doesn't implement HttpBody`
371    pub(crate) async fn sign_empty(
372        &self,
373        mut req: Request<()>,
374        subgraph_name: &str,
375    ) -> Result<Request<()>, BoxError> {
376        let credentials = self.credentials().await?;
377        let builder = self.signing_params_builder(&credentials).await?;
378        let (parts, _) = req.into_parts();
379        // Depending on the service, AWS refuses sigv4 payloads that contain specific headers.
380        // We'll go with default signed headers
381        let headers = HeaderMap::<&'static str>::default();
382        // UnsignedPayload only applies to lattice
383        let signable_request = SignableRequest::new(
384            parts.method.as_str(),
385            parts.uri.to_string(),
386            headers.iter().map(|(name, value)| (name.as_str(), *value)),
387            match self.service_name.as_str() {
388                "vpc-lattice-svcs" => SignableBody::UnsignedPayload,
389                _ => SignableBody::Bytes(&[]),
390            },
391        )?;
392
393        let signing_params = builder.build().expect("all required fields set");
394
395        let (signing_instructions, _signature) = sign(signable_request, &signing_params.into())
396            .map_err(|err| {
397                increment_failure_counter(subgraph_name);
398                let error = format!("failed to sign GraphQL body for AWS SigV4: {err}");
399                tracing::error!("{}", error);
400                error
401            })?
402            .into_parts();
403        req = Request::<()>::from_parts(parts, ());
404        signing_instructions.apply_to_request_http1x(&mut req);
405        increment_success_counter(subgraph_name);
406        Ok(req)
407    }
408
409    async fn signing_params_builder<'s>(
410        &'s self,
411        identity: &'s Identity,
412    ) -> Result<aws_sigv4::sign::v4::signing_params::Builder<'s, SigningSettings>, BoxError> {
413        let settings = get_signing_settings(self);
414        let builder = aws_sigv4::sign::v4::SigningParams::builder()
415            .identity(identity)
416            .region(self.region.as_ref())
417            .name(&self.service_name)
418            .time(SystemTime::now())
419            .settings(settings);
420        Ok(builder)
421    }
422
423    async fn credentials(&self) -> Result<Identity, BoxError> {
424        self.credentials_provider
425            .provide_credentials()
426            .await
427            .map_err(|err| {
428                increment_failure_counter(self.subgraph_name.as_str());
429                let error = format!("failed to get credentials for AWS SigV4 signing: {err}");
430                tracing::error!("{}", error);
431                error.into()
432            })
433            .map(Into::into)
434    }
435}
436
/// Records one successfully signed request for `subgraph_name` in the
/// `apollo.router.operations.authentication.aws.sigv4` counter, with `failed = false`.
fn increment_success_counter(subgraph_name: &str) {
    u64_counter!(
        "apollo.router.operations.authentication.aws.sigv4",
        "Number of subgraph requests signed with AWS SigV4",
        1,
        authentication.aws.sigv4.failed = false,
        subgraph.service.name = subgraph_name.to_string()
    );
}
/// Records one failed signing attempt for `subgraph_name` in the
/// `apollo.router.operations.authentication.aws.sigv4` counter, with `failed = true`.
fn increment_failure_counter(subgraph_name: &str) {
    u64_counter!(
        "apollo.router.operations.authentication.aws.sigv4",
        "Number of subgraph requests signed with AWS SigV4",
        1,
        authentication.aws.sigv4.failed = true,
        subgraph.service.name = subgraph_name.to_string()
    );
}
455
456pub(super) async fn make_signing_params(
457    config: &AuthConfig,
458    subgraph_name: &str,
459) -> Result<SigningParamsConfig, BoxError> {
460    match config {
461        AuthConfig::AWSSigV4(config) => {
462            let credentials_provider = config.get_credentials_provider().await;
463            Ok(SigningParamsConfig {
464                region: config.region(),
465                service_name: config.service_name(),
466                credentials_provider: CredentialsProvider::from_provide_credentials(
467                    credentials_provider,
468                )
469                .await
470                .map_err(BoxError::from)?,
471                subgraph_name: subgraph_name.to_string(),
472            })
473        }
474    }
475}
476
477/// There are three possible cases
478/// https://github.com/awslabs/aws-sdk-rust/blob/9c3168dafa4fd8885ce4e1fd41cec55ce982a33c/sdk/aws-sigv4/src/http_request/sign.rs#L264C1-L271C6
479fn get_signing_settings(signing_params: &SigningParamsConfig) -> SigningSettings {
480    let mut settings = SigningSettings::default();
481    settings.payload_checksum_kind = match signing_params.service_name.as_str() {
482        "appsync" | "s3" | "vpc-lattice-svcs" => PayloadChecksumKind::XAmzSha256,
483        _ => PayloadChecksumKind::NoHeader,
484    };
485    settings
486}
487
/// Attaches the matching AWS SigV4 signing parameters to subgraph request contexts
/// (see `subgraph_service`).
pub(super) struct SubgraphAuth {
    // Resolved signing parameters: the global `all` entry plus per-subgraph overrides.
    pub(super) signing_params: Arc<SigningParams>,
}
491
492impl SubgraphAuth {
493    pub(super) fn subgraph_service(
494        &self,
495        name: &str,
496        service: crate::services::subgraph::BoxService,
497    ) -> crate::services::subgraph::BoxService {
498        if let Some(signing_params) = self.params_for_service(name) {
499            ServiceBuilder::new()
500                .map_request(move |req: SubgraphRequest| {
501                    let signing_params = signing_params.clone();
502                    req.context
503                        .extensions()
504                        .with_lock(|lock| lock.insert(signing_params));
505                    req
506                })
507                .service(service)
508                .boxed()
509        } else {
510            service
511        }
512    }
513}
514
515impl SubgraphAuth {
516    fn params_for_service(&self, service_name: &str) -> Option<Arc<SigningParamsConfig>> {
517        self.signing_params
518            .subgraphs
519            .get(service_name)
520            .cloned()
521            .or_else(|| self.signing_params.all.clone())
522    }
523}
524
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use http::header::CONTENT_LENGTH;
    use http::header::CONTENT_TYPE;
    use http::header::HOST;
    use regex::Regex;
    use tower::Service;

    use super::*;
    use crate::Context;
    use crate::graphql::Request;
    use crate::plugin::test::MockSubgraphService;
    use crate::query_planner::fetch::OperationKind;
    use crate::services::SubgraphRequest;
    use crate::services::SubgraphResponse;
    use crate::services::subgraph::SubgraphRequestId;

    // Builds hardcoded SigV4 params for `service_name` and returns the settings that
    // `get_signing_settings` derives from them.
    async fn test_signing_settings(service_name: &str) -> SigningSettings {
        let params: SigningParamsConfig = make_signing_params(
            &AuthConfig::AWSSigV4(AWSSigV4Config::Hardcoded(AWSSigV4HardcodedConfig {
                access_key_id: "id".to_string(),
                secret_access_key: "secret".to_string(),
                region: "us-east-1".to_string(),
                service_name: service_name.to_string(),
                assume_role: None,
            })),
            "all",
        )
        .await
        .unwrap();
        get_signing_settings(&params)
    }

    // Only s3, vpc-lattice-svcs and appsync get the x-amz-content-sha256 header.
    #[tokio::test]
    async fn test_get_signing_settings() {
        assert_eq!(
            PayloadChecksumKind::XAmzSha256,
            test_signing_settings("s3").await.payload_checksum_kind
        );
        assert_eq!(
            PayloadChecksumKind::XAmzSha256,
            test_signing_settings("vpc-lattice-svcs")
                .await
                .payload_checksum_kind
        );
        assert_eq!(
            PayloadChecksumKind::XAmzSha256,
            test_signing_settings("appsync").await.payload_checksum_kind
        );
        assert_eq!(
            PayloadChecksumKind::NoHeader,
            test_signing_settings("something-else")
                .await
                .payload_checksum_kind
        );
    }

    // The configuration tests below only check that the YAML deserializes.
    #[test]
    fn test_all_aws_sig_v4_hardcoded_config() {
        serde_yaml::from_str::<Config>(
            r#"
        all:
          aws_sig_v4:
            hardcoded:
              access_key_id: "test"
              secret_access_key: "test"
              region: "us-east-1"
              service_name: "lambda"
        "#,
        )
        .unwrap();
    }

    #[test]
    fn test_subgraph_aws_sig_v4_hardcoded_config() {
        serde_yaml::from_str::<Config>(
            r#"
        subgraphs:
          products:
            aws_sig_v4:
              hardcoded:
                access_key_id: "test"
                secret_access_key: "test"
                region: "us-east-1"
                service_name: "test_service"
        "#,
        )
        .unwrap();
    }

    #[test]
    fn test_aws_sig_v4_default_chain_assume_role_config() {
        serde_yaml::from_str::<Config>(
            r#"
        all:
            aws_sig_v4:
                default_chain:
                    profile_name: "my-test-profile"
                    region: "us-east-1"
                    service_name: "lambda"
                    assume_role:
                        role_arn: "test-arn"
                        session_name: "test-session"
                        external_id: "test-id"
        "#,
        )
        .unwrap();
    }

    // VPC Lattice requests must carry `x-amz-content-sha256: UNSIGNED-PAYLOAD`.
    #[tokio::test]
    async fn test_lattice_body_payload_should_be_unsigned() -> Result<(), BoxError> {
        let subgraph_request = example_request();

        let mut mock = MockSubgraphService::new();
        mock.expect_call()
            .times(1)
            .withf(|request| {
                let http_request = get_signed_request(request, "products".to_string());
                assert_eq!(
                    "UNSIGNED-PAYLOAD",
                    http_request
                        .headers()
                        .get("x-amz-content-sha256")
                        .unwrap()
                        .to_str()
                        .unwrap()
                );
                true
            })
            .returning(example_response);

        let mut service = SubgraphAuth {
            signing_params: Arc::new(SigningParams {
                all: make_signing_params(
                    &AuthConfig::AWSSigV4(AWSSigV4Config::Hardcoded(AWSSigV4HardcodedConfig {
                        access_key_id: "id".to_string(),
                        secret_access_key: "secret".to_string(),
                        region: "us-east-1".to_string(),
                        service_name: "vpc-lattice-svcs".to_string(),
                        assume_role: None,
                    })),
                    "all",
                )
                .await
                .ok()
                .map(Arc::new),
                subgraphs: Default::default(),
            }),
        }
        .subgraph_service("test_subgraph", mock.boxed());

        service.ready().await?.call(subgraph_request).await?;
        Ok(())
    }

    // Checks the shape of the authorization / x-amz-date headers and the body hash for s3.
    #[tokio::test]
    async fn test_aws_sig_v4_headers() -> Result<(), BoxError> {
        let subgraph_request = example_request();

        let mut mock = MockSubgraphService::new();
        mock.expect_call()
            .times(1)
            .withf(|request| {
                let http_request = get_signed_request(request, "products".to_string());
                let authorization_regex = Regex::new(r"AWS4-HMAC-SHA256 Credential=id/\d{8}/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=[a-f0-9]{64}").unwrap();
                let authorization_header_str = http_request.headers().get("authorization").unwrap().to_str().unwrap();
                assert_eq!(match authorization_regex.find(authorization_header_str) {
                    Some(m) => m.as_str(),
                    None => "no match"
                }, authorization_header_str);

                let x_amz_date_regex = Regex::new(r"\d{8}T\d{6}Z").unwrap();
                let x_amz_date_header_str = http_request.headers().get("x-amz-date").unwrap().to_str().unwrap();
                assert_eq!(match x_amz_date_regex.find(x_amz_date_header_str) {
                    Some(m) => m.as_str(),
                    None => "no match"
                }, x_amz_date_header_str);

                assert_eq!(http_request.headers().get("x-amz-content-sha256").unwrap(), "255959b4c6e11c1080f61ce0d75eb1b565c1772173335a7828ba9c13c25c0d8c");

                true
            })
            .returning(example_response);

        let mut service = SubgraphAuth {
            signing_params: Arc::new(SigningParams {
                all: make_signing_params(
                    &AuthConfig::AWSSigV4(AWSSigV4Config::Hardcoded(AWSSigV4HardcodedConfig {
                        access_key_id: "id".to_string(),
                        secret_access_key: "secret".to_string(),
                        region: "us-east-1".to_string(),
                        service_name: "s3".to_string(),
                        assume_role: None,
                    })),
                    "all",
                )
                .await
                .ok()
                .map(Arc::new),
                subgraphs: Default::default(),
            }),
        }
        .subgraph_service("test_subgraph", mock.boxed());

        service.ready().await?.call(subgraph_request).await?;
        Ok(())
    }

    // Credentials without an expiry are fetched once, then only on an explicit refresh.
    #[tokio::test]
    async fn test_credentials_provider_keeps_credentials_in_cache() -> Result<(), BoxError> {
        #[derive(Debug, Default, Clone)]
        struct TestCredentialsProvider {
            times_called: Arc<AtomicUsize>,
        }

        impl ProvideCredentials for TestCredentialsProvider {
            fn provide_credentials<'a>(
                &'a self,
            ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
            where
                Self: 'a,
            {
                self.times_called.fetch_add(1, Ordering::SeqCst);
                aws_credential_types::provider::future::ProvideCredentials::ready(Ok(
                    Credentials::new("test_key", "test_secret", None, None, "test_provider"),
                ))
            }
        }

        let tcp = TestCredentialsProvider::default();

        let cp = CredentialsProvider::from_provide_credentials(tcp.clone())
            .await
            .unwrap();

        let _ = cp.provide_credentials().await.unwrap();
        let _ = cp.provide_credentials().await.unwrap();

        assert_eq!(1, tcp.times_called.load(Ordering::SeqCst));

        cp.refresh_credentials().await;
        tokio::time::sleep(Duration::from_millis(50)).await;

        let _ = cp.provide_credentials().await.unwrap();
        let _ = cp.provide_credentials().await.unwrap();

        assert_eq!(2, tcp.times_called.load(Ordering::SeqCst));

        Ok(())
    }

    // Credentials close to expiry are also refreshed by the background timer.
    // NOTE(review): timing-based; the final assertion relies on the ~1s refresh timer
    // having fired within the 1s + 50ms slept.
    #[tokio::test]
    async fn test_credentials_provider_refresh_on_stale() -> Result<(), BoxError> {
        #[derive(Debug, Default, Clone)]
        struct TestCredentialsProvider {
            times_called: Arc<AtomicUsize>,
        }

        impl ProvideCredentials for TestCredentialsProvider {
            fn provide_credentials<'a>(
                &'a self,
            ) -> aws_credential_types::provider::future::ProvideCredentials<'a>
            where
                Self: 'a,
            {
                self.times_called.fetch_add(1, Ordering::SeqCst);
                aws_credential_types::provider::future::ProvideCredentials::ready(Ok(
                    // The credentials expire 1 second past the MIN_REMAINING_DURATION
                    // refresh margin, so the background task refreshes them after ~1 second.
                    Credentials::new(
                        "test_key",
                        "test_secret",
                        None,
                        // 5 minutes + 1 second
                        SystemTime::now().checked_add(Duration::from_secs(60 * 5 + 1)),
                        "test_provider",
                    ),
                ))
            }
        }

        let tcp = TestCredentialsProvider::default();

        let cp = CredentialsProvider::from_provide_credentials(tcp.clone())
            .await
            .unwrap();

        let _ = cp.provide_credentials().await.unwrap();
        let _ = cp.provide_credentials().await.unwrap();

        assert_eq!(1, tcp.times_called.load(Ordering::SeqCst));

        cp.refresh_credentials().await;
        tokio::time::sleep(Duration::from_millis(50)).await;

        let _ = cp.provide_credentials().await.unwrap();
        let _ = cp.provide_credentials().await.unwrap();

        assert_eq!(2, tcp.times_called.load(Ordering::SeqCst));

        tokio::time::sleep(Duration::from_secs(1)).await;

        assert_eq!(3, tcp.times_called.load(Ordering::SeqCst));

        Ok(())
    }

    // Minimal successful response for the mocked subgraph service.
    fn example_response(req: SubgraphRequest) -> Result<SubgraphResponse, BoxError> {
        Ok(SubgraphResponse::new_from_response(
            http::Response::default(),
            Context::new(),
            req.subgraph_name,
            SubgraphRequestId(String::new()),
        ))
    }

    // Subgraph request pointing at https://test-endpoint.com with a small GraphQL body.
    fn example_request() -> SubgraphRequest {
        SubgraphRequest::builder()
            .supergraph_request(Arc::new(
                http::Request::builder()
                    .header(HOST, "host")
                    .header(CONTENT_LENGTH, "2")
                    .header(CONTENT_TYPE, "graphql")
                    .body(
                        Request::builder()
                            .query("query")
                            .operation_name("my_operation_name")
                            .build(),
                    )
                    .expect("expecting valid request"),
            ))
            .subgraph_request(
                http::Request::builder()
                    .header(HOST, "rhost")
                    .header(CONTENT_LENGTH, "22")
                    .header(CONTENT_TYPE, "graphql")
                    .uri("https://test-endpoint.com")
                    .body(Request::builder().query("query").build())
                    .expect("expecting valid request"),
            )
            .operation_kind(OperationKind::Query)
            .context(Context::new())
            .subgraph_name(String::default())
            .build()
    }

    // Signs `request` with the SigningParamsConfig that `SubgraphAuth` stored in its
    // context, using the JSON-serialized GraphQL request as the body.
    // The mock's `withf` predicate is synchronous, so the async `sign` is driven on a
    // fresh runtime in a separate thread (presumably because a runtime cannot be blocked
    // on from inside the test's own runtime).
    fn get_signed_request(
        request: &SubgraphRequest,
        service_name: String,
    ) -> http::Request<RouterBody> {
        let signing_params = request
            .context
            .extensions()
            .with_lock(|lock| lock.get::<Arc<SigningParamsConfig>>().cloned())
            .unwrap();

        let http_request = request
            .clone()
            .subgraph_request
            .map(|body| router::body::from_bytes(serde_json::to_string(&body).unwrap()));

        std::thread::spawn(move || {
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                signing_params
                    .sign(http_request, service_name.as_str())
                    .await
                    .unwrap()
            })
        })
        .join()
        .unwrap()
    }
}