Skip to main content

aws_config/
ecs.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Ecs Credentials Provider
7//!
8//! This credential provider is frequently used with an AWS-provided credentials service (e.g.
9//! [IAM Roles for tasks](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html)).
10//! However, it's possible to use environment variables to configure this provider to use your own
11//! credentials sources.
12//!
13//! This provider is part of the [default credentials chain](crate::default_provider::credentials).
14//!
15//! ## Configuration
16//! **First**: It will check the value of `$AWS_CONTAINER_CREDENTIALS_RELATIVE_URI`. It will use this
17//! to construct a URI rooted at `http://169.254.170.2`. For example, if the value of the environment
18//! variable was `/credentials`, the SDK would look for credentials at `http://169.254.170.2/credentials`.
19//!
20//! **Next**: It will check the value of `$AWS_CONTAINER_CREDENTIALS_FULL_URI`. This specifies the full
21//! URL to load credentials. The URL MUST satisfy one of the following two properties:
22//! 1. The URL begins with `https`
23//! 2. The URL refers to an allowed IP address. If a URL contains a domain name instead of an IP address,
24//!    a DNS lookup will be performed. ALL resolved IP addresses MUST refer to an allowed IP address, or
25//!    the credentials provider will return `CredentialsError::InvalidConfiguration`. Valid IP addresses are:
26//!    a) Loopback interfaces
27//!    b) The [ECS Task Metadata V2](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-metadata-endpoint-v2.html)
28//!    address ie 169.254.170.2.
29//!    c) [EKS Pod Identity](https://docs.aws.amazon.com/eks/latest/userguide/pod-identities.html) addresses
30//!    ie 169.254.170.23 or fd00:ec2::23
31//!
32//! **Next**: It will check the value of `$AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE`. If this is set,
33//! the filename specified will be read, and the value passed in the `Authorization` header. If the file
34//! cannot be read, an error is returned.
35//!
36//! **Finally**: It will check the value of `$AWS_CONTAINER_AUTHORIZATION_TOKEN`. If this is set, the
37//! value will be passed in the `Authorization` header.
38//!
39//! ## Credentials Format
40//! Credentials MUST be returned in a JSON format:
41//! ```json
42//! {
43//!    "AccessKeyId" : "MUA...",
44//!    "SecretAccessKey" : "/7PC5om....",
45//!    "Token" : "AQoDY....=",
46//!    "Expiration" : "2016-02-25T06:03:31Z"
47//!  }
48//! ```
49//!
50//! Credentials errors MAY be returned with a `code` and `message` field:
51//! ```json
52//! {
53//!   "code": "ErrorCode",
54//!   "message": "Helpful error message."
55//! }
56//! ```
57
58use crate::http_credential_provider::HttpCredentialProvider;
59use crate::provider_config::ProviderConfig;
60use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
61use aws_smithy_http::endpoint::apply_endpoint;
62use aws_smithy_runtime_api::client::dns::{ResolveDns, ResolveDnsError, SharedDnsResolver};
63use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
64use aws_smithy_runtime_api::shared::IntoShared;
65use aws_smithy_types::error::display::DisplayErrorContext;
66use aws_types::os_shim_internal::{Env, Fs};
67use http::header::InvalidHeaderValue;
68use http::uri::{InvalidUri, PathAndQuery, Scheme};
69use http::{HeaderValue, Uri};
70use std::error::Error;
71use std::fmt::{Display, Formatter};
72use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
73use std::time::Duration;
74use tokio::sync::OnceCell;
75
/// Default read timeout for the credentials HTTP client.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5);
/// Default connect timeout for the credentials HTTP client.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);

// URL from https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-metadata-endpoint-v2.html
// Relative credential URIs are resolved against this base.
const BASE_HOST: &str = "http://169.254.170.2";
// Environment variable holding a URI relative to `BASE_HOST`.
const ENV_RELATIVE_URI: &str = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI";
// Environment variable holding a complete credentials URI.
const ENV_FULL_URI: &str = "AWS_CONTAINER_CREDENTIALS_FULL_URI";
// Environment variable holding the `Authorization` header value directly.
const ENV_AUTHORIZATION_TOKEN: &str = "AWS_CONTAINER_AUTHORIZATION_TOKEN";
// Environment variable naming a file whose contents are the `Authorization` header value.
const ENV_AUTHORIZATION_TOKEN_FILE: &str = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE";
85
/// Credential provider for ECS and generalized HTTP credentials
///
/// See the [module](crate::ecs) documentation for more details.
///
/// This credential provider is part of the default chain.
#[derive(Debug)]
pub struct EcsCredentialsProvider {
    /// Inner provider, initialized at most once on first use via `Provider::make`.
    inner: OnceCell<Provider>,
    /// Environment consulted for the authorization token variables on every call.
    env: Env,
    /// File system used to read the authorization token file.
    fs: Fs,
    /// Builder kept around so the inner provider can be constructed lazily.
    builder: Builder,
}
98
99impl EcsCredentialsProvider {
100    /// Builder for [`EcsCredentialsProvider`]
101    pub fn builder() -> Builder {
102        Builder::default()
103    }
104
105    /// Load credentials from this credentials provider
106    pub async fn credentials(&self) -> provider::Result {
107        let env_token_file = self.env.get(ENV_AUTHORIZATION_TOKEN_FILE).ok();
108        let env_token = self.env.get(ENV_AUTHORIZATION_TOKEN).ok();
109        let auth = if let Some(auth_token_file) = env_token_file {
110            let auth = self
111                .fs
112                .read_to_end(auth_token_file)
113                .await
114                .map_err(CredentialsError::provider_error)?;
115            Some(HeaderValue::from_bytes(auth.as_slice()).map_err(|err| {
116                tracing::warn!(
117                    token_length = auth.len(),
118                    ends_with_whitespace = auth
119                        .last()
120                        .map(|b| b.is_ascii_whitespace())
121                        .unwrap_or(false),
122                    "invalid auth token from file"
123                );
124                CredentialsError::invalid_configuration(EcsConfigurationError::InvalidAuthToken {
125                    err,
126                })
127            })?)
128        } else if let Some(auth_token) = env_token {
129            Some(HeaderValue::from_str(&auth_token).map_err(|err| {
130                tracing::warn!(
131                    token_length = auth_token.len(),
132                    ends_with_whitespace = auth_token
133                        .chars()
134                        .last()
135                        .map(|c| c.is_ascii_whitespace())
136                        .unwrap_or(false),
137                    "invalid auth token from env"
138                );
139                CredentialsError::invalid_configuration(EcsConfigurationError::InvalidAuthToken {
140                    err,
141                })
142            })?)
143        } else {
144            None
145        };
146        match self.provider().await {
147            Provider::NotConfigured => {
148                Err(CredentialsError::not_loaded("ECS provider not configured"))
149            }
150            Provider::InvalidConfiguration(err) => {
151                Err(CredentialsError::invalid_configuration(format!("{err}")))
152            }
153            Provider::Configured(provider) => provider.credentials(auth).await,
154        }
155    }
156
157    async fn provider(&self) -> &Provider {
158        self.inner
159            .get_or_init(|| Provider::make(self.builder.clone()))
160            .await
161    }
162}
163
164impl ProvideCredentials for EcsCredentialsProvider {
165    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
166    where
167        Self: 'a,
168    {
169        future::ProvideCredentials::new(self.credentials())
170    }
171}
172
/// Inner Provider that can record failed configuration state
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum Provider {
    /// A usable HTTP credential provider was built from the environment.
    Configured(HttpCredentialProvider),
    /// Neither the relative nor the full URI environment variable was set.
    NotConfigured,
    /// The environment was set but could not be turned into a valid URI.
    InvalidConfiguration(EcsConfigurationError),
}
181
182impl Provider {
183    async fn uri(env: Env, dns: Option<SharedDnsResolver>) -> Result<Uri, EcsConfigurationError> {
184        let relative_uri = env.get(ENV_RELATIVE_URI).ok();
185        let full_uri = env.get(ENV_FULL_URI).ok();
186        if let Some(relative_uri) = relative_uri {
187            Self::build_full_uri(relative_uri)
188        } else if let Some(full_uri) = full_uri {
189            let dns = dns.or_else(default_dns);
190            validate_full_uri(&full_uri, dns)
191                .await
192                .map_err(|err| EcsConfigurationError::InvalidFullUri { err, uri: full_uri })
193        } else {
194            Err(EcsConfigurationError::NotConfigured)
195        }
196    }
197
198    async fn make(builder: Builder) -> Self {
199        let provider_config = builder.provider_config.unwrap_or_default();
200        let env = provider_config.env();
201        let uri = match Self::uri(env, builder.dns).await {
202            Ok(uri) => uri,
203            Err(EcsConfigurationError::NotConfigured) => return Provider::NotConfigured,
204            Err(err) => return Provider::InvalidConfiguration(err),
205        };
206        let path_and_query = match uri.path_and_query() {
207            Some(path_and_query) => path_and_query.to_string(),
208            None => uri.path().to_string(),
209        };
210        let endpoint = {
211            let mut parts = uri.into_parts();
212            parts.path_and_query = Some(PathAndQuery::from_static("/"));
213            Uri::from_parts(parts)
214        }
215        .expect("parts will be valid")
216        .to_string();
217
218        let http_provider = HttpCredentialProvider::builder()
219            .configure(&provider_config)
220            .http_connector_settings(
221                HttpConnectorSettings::builder()
222                    .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
223                    .read_timeout(DEFAULT_READ_TIMEOUT)
224                    .build(),
225            )
226            .build("EcsContainer", &endpoint, path_and_query);
227        Provider::Configured(http_provider)
228    }
229
230    fn build_full_uri(relative_uri: String) -> Result<Uri, EcsConfigurationError> {
231        let mut relative_uri = match relative_uri.parse::<Uri>() {
232            Ok(uri) => uri,
233            Err(invalid_uri) => {
234                tracing::warn!(uri = %DisplayErrorContext(&invalid_uri), "invalid URI loaded from environment");
235                return Err(EcsConfigurationError::InvalidRelativeUri {
236                    err: invalid_uri,
237                    uri: relative_uri,
238                });
239            }
240        };
241        let endpoint = Uri::from_static(BASE_HOST);
242        apply_endpoint(&mut relative_uri, &endpoint, None)
243            .expect("appending relative URLs to the ECS endpoint should always succeed");
244        Ok(relative_uri)
245    }
246}
247
/// Reasons the ECS provider could not be configured from the environment.
#[derive(Debug)]
enum EcsConfigurationError {
    /// The relative URI could not be parsed; `uri` is the offending value.
    InvalidRelativeUri {
        err: InvalidUri,
        uri: String,
    },
    /// The full URI failed validation; `uri` is the offending value.
    InvalidFullUri {
        err: InvalidFullUriError,
        uri: String,
    },
    /// The authorization token could not be converted into an HTTP header value.
    InvalidAuthToken {
        err: InvalidHeaderValue,
    },
    /// Neither URI environment variable was set.
    NotConfigured,
}
263
264impl Display for EcsConfigurationError {
265    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
266        match self {
267            EcsConfigurationError::InvalidRelativeUri { err, uri } => {
268                write!(f, "invalid relative URI for ECS provider ({err}): {uri}",)
269            }
270            EcsConfigurationError::InvalidFullUri { err, uri } => {
271                write!(f, "invalid full URI for ECS provider ({err}): {uri}")
272            }
273            EcsConfigurationError::NotConfigured => write!(
274                f,
275                "No environment variables were set to configure ECS provider"
276            ),
277            EcsConfigurationError::InvalidAuthToken { err } => write!(
278                f,
279                "the auth token could not be used as an HTTP header value. {err}",
280            ),
281        }
282    }
283}
284
285impl Error for EcsConfigurationError {
286    fn source(&self) -> Option<&(dyn Error + 'static)> {
287        match &self {
288            EcsConfigurationError::InvalidRelativeUri { err, .. } => Some(err),
289            EcsConfigurationError::InvalidFullUri { err, .. } => Some(err),
290            EcsConfigurationError::InvalidAuthToken { err, .. } => Some(err),
291            EcsConfigurationError::NotConfigured => None,
292        }
293    }
294}
295
/// Builder for [`EcsCredentialsProvider`]
#[derive(Default, Debug, Clone)]
pub struct Builder {
    /// Provider configuration override; the default configuration is used when unset.
    provider_config: Option<ProviderConfig>,
    /// DNS resolver override used when validating full URIs.
    dns: Option<SharedDnsResolver>,
    /// Connect timeout override recorded by `connect_timeout`.
    connect_timeout: Option<Duration>,
    /// Read timeout override recorded by `read_timeout`.
    read_timeout: Option<Duration>,
}
304
305impl Builder {
306    /// Override the configuration used for this provider
307    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
308        self.provider_config = Some(provider_config.clone());
309        self
310    }
311
312    /// Override the DNS resolver used to validate URIs
313    ///
314    /// URIs must refer to valid IP addresses as defined in the module documentation. The [`ResolveDns`]
315    /// implementation is used to retrieve IP addresses for a given domain.
316    pub fn dns(mut self, dns: impl ResolveDns + 'static) -> Self {
317        self.dns = Some(dns.into_shared());
318        self
319    }
320
321    /// Override the connect timeout for the HTTP client
322    ///
323    /// This value defaults to 2 seconds
324    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
325        self.connect_timeout = Some(timeout);
326        self
327    }
328
329    /// Override the read timeout for the HTTP client
330    ///
331    /// This value defaults to 5 seconds
332    pub fn read_timeout(mut self, timeout: Duration) -> Self {
333        self.read_timeout = Some(timeout);
334        self
335    }
336
337    /// Create an [`EcsCredentialsProvider`] from this builder
338    pub fn build(self) -> EcsCredentialsProvider {
339        let env = self
340            .provider_config
341            .as_ref()
342            .map(|config| config.env())
343            .unwrap_or_default();
344        let fs = self
345            .provider_config
346            .as_ref()
347            .map(|config| config.fs())
348            .unwrap_or_default();
349        EcsCredentialsProvider {
350            inner: OnceCell::new(),
351            env,
352            fs,
353            builder: self,
354        }
355    }
356}
357
/// The specific reason a full URI was rejected by `validate_full_uri`.
#[derive(Debug)]
enum InvalidFullUriErrorKind {
    /// The provided URI could not be parsed as a URI
    #[non_exhaustive]
    InvalidUri(InvalidUri),

    /// No Dns resolver was provided
    #[non_exhaustive]
    NoDnsResolver,

    /// The URI did not specify a host
    #[non_exhaustive]
    MissingHost,

    /// The URI did not refer to an allowed IP address
    #[non_exhaustive]
    DisallowedIP,

    /// DNS lookup failed when attempting to resolve the host to an IP Address for validation.
    DnsLookupFailed(ResolveDnsError),
}
379
/// Invalid Full URI
///
/// When the full URI setting is used, the URI must either be HTTPS, point to a loopback interface,
/// or point to known ECS/EKS container IPs.
#[derive(Debug)]
pub struct InvalidFullUriError {
    /// Private classification of the failure; drives `Display` and `Error::source`.
    kind: InvalidFullUriErrorKind,
}
388
389impl Display for InvalidFullUriError {
390    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
391        use InvalidFullUriErrorKind::*;
392        match self.kind {
393            InvalidUri(_) => write!(f, "URI was invalid"),
394            MissingHost => write!(f, "URI did not specify a host"),
395            DisallowedIP => {
396                write!(f, "URI did not refer to an allowed IP address")
397            }
398            DnsLookupFailed(_) => {
399                write!(
400                    f,
401                    "failed to perform DNS lookup while validating URI"
402                )
403            }
404            NoDnsResolver => write!(f, "no DNS resolver was provided. Enable `rt-tokio` or provide a `dns` resolver to the builder.")
405        }
406    }
407}
408
409impl Error for InvalidFullUriError {
410    fn source(&self) -> Option<&(dyn Error + 'static)> {
411        use InvalidFullUriErrorKind::*;
412        match &self.kind {
413            InvalidUri(err) => Some(err),
414            DnsLookupFailed(err) => Some(err as _),
415            _ => None,
416        }
417    }
418}
419
420impl From<InvalidFullUriErrorKind> for InvalidFullUriError {
421    fn from(kind: InvalidFullUriErrorKind) -> Self {
422        Self { kind }
423    }
424}
425
426/// Validate that `uri` is valid to be used as a full provider URI
427/// Either:
428/// 1. The URL is uses `https`
429/// 2. The URL refers to an allowed IP. If a URL contains a domain name instead of an IP address,
430///    a DNS lookup will be performed. ALL resolved IP addresses MUST refer to an allowed IP, or
431///    the credentials provider will return `CredentialsError::InvalidConfiguration`. Allowed IPs
432///    are the loopback interfaces, and the known ECS/EKS container IPs.
433async fn validate_full_uri(
434    uri: &str,
435    dns: Option<SharedDnsResolver>,
436) -> Result<Uri, InvalidFullUriError> {
437    let uri = uri
438        .parse::<Uri>()
439        .map_err(InvalidFullUriErrorKind::InvalidUri)?;
440    if uri.scheme() == Some(&Scheme::HTTPS) {
441        return Ok(uri);
442    }
443    // For HTTP URIs, we need to validate that it points to a valid IP
444    let host = uri.host().ok_or(InvalidFullUriErrorKind::MissingHost)?;
445    let maybe_ip = if host.starts_with('[') && host.ends_with(']') {
446        host[1..host.len() - 1].parse::<IpAddr>()
447    } else {
448        host.parse::<IpAddr>()
449    };
450    let is_allowed = match maybe_ip {
451        Ok(addr) => is_full_uri_ip_allowed(&addr),
452        Err(_domain_name) => {
453            let dns = dns.ok_or(InvalidFullUriErrorKind::NoDnsResolver)?;
454            dns.resolve_dns(host)
455                .await
456                .map_err(|err| InvalidFullUriErrorKind::DnsLookupFailed(ResolveDnsError::new(err)))?
457                .iter()
458                    .all(|addr| {
459                        if !is_full_uri_ip_allowed(addr) {
460                            tracing::warn!(
461                                addr = ?addr,
462                                "HTTP credential provider cannot be used: Address does not resolve to an allowed IP."
463                            )
464                        };
465                        is_full_uri_ip_allowed(addr)
466                    })
467        }
468    };
469    match is_allowed {
470        true => Ok(uri),
471        false => Err(InvalidFullUriErrorKind::DisallowedIP.into()),
472    }
473}
474
// "169.254.170.2"
const ECS_CONTAINER_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(169, 254, 170, 2));

// "169.254.170.23"
const EKS_CONTAINER_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(169, 254, 170, 23));

// "fd00:ec2::23"
const EKS_CONTAINER_IPV6: IpAddr = IpAddr::V6(Ipv6Addr::new(0xFD00, 0x0EC2, 0, 0, 0, 0, 0, 0x23));

/// Returns `true` when `ip` is a loopback address or one of the known ECS/EKS container
/// addresses above.
fn is_full_uri_ip_allowed(ip: &IpAddr) -> bool {
    const CONTAINER_ADDRS: [IpAddr; 3] =
        [ECS_CONTAINER_IPV4, EKS_CONTAINER_IPV4, EKS_CONTAINER_IPV6];
    ip.is_loopback() || CONTAINER_ADDRS.contains(ip)
}
489
490/// Default DNS resolver impl
491///
492/// DNS resolution is required to validate that provided URIs point to a valid IP address
493#[cfg(any(not(feature = "rt-tokio"), target_family = "wasm"))]
494fn default_dns() -> Option<SharedDnsResolver> {
495    None
496}
497#[cfg(all(feature = "rt-tokio", not(target_family = "wasm")))]
498fn default_dns() -> Option<SharedDnsResolver> {
499    use aws_smithy_runtime::client::dns::TokioDnsResolver;
500    Some(TokioDnsResolver::new().into_shared())
501}
502
503#[cfg(test)]
504mod test {
505    use super::*;
506    use crate::provider_config::ProviderConfig;
507    use crate::test_case::{no_traffic_client, GenericTestResult};
508    use aws_credential_types::provider::ProvideCredentials;
509    use aws_credential_types::Credentials;
510    use aws_smithy_async::future::never::Never;
511    use aws_smithy_async::rt::sleep::TokioSleep;
512    use aws_smithy_http_client::test_util::{ReplayEvent, StaticReplayClient};
513    use aws_smithy_runtime_api::client::dns::DnsFuture;
514    use aws_smithy_runtime_api::client::http::HttpClient;
515    use aws_smithy_runtime_api::shared::IntoShared;
516    use aws_smithy_types::body::SdkBody;
517    use aws_types::os_shim_internal::Env;
518    use futures_util::FutureExt;
519    use http::header::AUTHORIZATION;
520    use http::Uri;
521    use serde::Deserialize;
522    use std::collections::HashMap;
523    use std::error::Error;
524    use std::ffi::OsString;
525    use std::net::IpAddr;
526    use std::time::{Duration, UNIX_EPOCH};
527    use tracing_test::traced_test;
528
529    fn provider(
530        env: Env,
531        fs: Fs,
532        http_client: impl HttpClient + 'static,
533    ) -> EcsCredentialsProvider {
534        let provider_config = ProviderConfig::empty()
535            .with_env(env)
536            .with_fs(fs)
537            .with_http_client(http_client)
538            .with_sleep_impl(TokioSleep::new());
539        Builder::default().configure(&provider_config).build()
540    }
541
542    #[derive(Deserialize)]
543    struct EcsUriTest {
544        env: HashMap<String, String>,
545        result: GenericTestResult<String>,
546    }
547
548    impl EcsUriTest {
549        async fn check(&self) {
550            let env = Env::from(self.env.clone());
551            let uri = Provider::uri(env, Some(TestDns::default().into_shared()))
552                .await
553                .map(|uri| uri.to_string());
554            self.result.assert_matches(uri.as_ref());
555        }
556    }
557
558    #[tokio::test]
559    async fn run_config_tests() -> Result<(), Box<dyn Error>> {
560        let test_cases = std::fs::read_to_string("test-data/ecs-tests.json")?;
561        #[derive(Deserialize)]
562        struct TestCases {
563            tests: Vec<EcsUriTest>,
564        }
565
566        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
567        let test_cases = test_cases.tests;
568        for test in test_cases {
569            test.check().await
570        }
571        Ok(())
572    }
573
574    #[test]
575    fn validate_uri_https() {
576        // over HTTPs, any URI is fine
577        let dns = Some(NeverDns.into_shared());
578        assert_eq!(
579            validate_full_uri("https://amazon.com", None)
580                .now_or_never()
581                .unwrap()
582                .expect("valid"),
583            Uri::from_static("https://amazon.com")
584        );
585        // over HTTP, it will try to lookup
586        assert!(
587            validate_full_uri("http://amazon.com", dns)
588                .now_or_never()
589                .is_none(),
590            "DNS lookup should occur, but it will never return"
591        );
592
593        let no_dns_error = validate_full_uri("http://amazon.com", None)
594            .now_or_never()
595            .unwrap()
596            .expect_err("DNS service is required");
597        assert!(
598            matches!(
599                no_dns_error,
600                InvalidFullUriError {
601                    kind: InvalidFullUriErrorKind::NoDnsResolver
602                }
603            ),
604            "expected no dns service, got: {}",
605            no_dns_error
606        );
607    }
608
609    #[test]
610    fn valid_uri_loopback() {
611        assert_eq!(
612            validate_full_uri("http://127.0.0.1:8080/get-credentials", None)
613                .now_or_never()
614                .unwrap()
615                .expect("valid uri"),
616            Uri::from_static("http://127.0.0.1:8080/get-credentials")
617        );
618
619        let err = validate_full_uri("http://192.168.10.120/creds", None)
620            .now_or_never()
621            .unwrap()
622            .expect_err("not a loopback");
623        assert!(matches!(
624            err,
625            InvalidFullUriError {
626                kind: InvalidFullUriErrorKind::DisallowedIP
627            }
628        ));
629    }
630
631    #[test]
632    fn valid_uri_ecs_eks() {
633        assert_eq!(
634            validate_full_uri("http://169.254.170.2:8080/get-credentials", None)
635                .now_or_never()
636                .unwrap()
637                .expect("valid uri"),
638            Uri::from_static("http://169.254.170.2:8080/get-credentials")
639        );
640        assert_eq!(
641            validate_full_uri("http://169.254.170.23:8080/get-credentials", None)
642                .now_or_never()
643                .unwrap()
644                .expect("valid uri"),
645            Uri::from_static("http://169.254.170.23:8080/get-credentials")
646        );
647        assert_eq!(
648            validate_full_uri("http://[fd00:ec2::23]:8080/get-credentials", None)
649                .now_or_never()
650                .unwrap()
651                .expect("valid uri"),
652            Uri::from_static("http://[fd00:ec2::23]:8080/get-credentials")
653        );
654
655        let err = validate_full_uri("http://169.254.171.23/creds", None)
656            .now_or_never()
657            .unwrap()
658            .expect_err("not an ecs/eks container address");
659        assert!(matches!(
660            err,
661            InvalidFullUriError {
662                kind: InvalidFullUriErrorKind::DisallowedIP
663            }
664        ));
665
666        let err = validate_full_uri("http://[fd00:ec2::2]/creds", None)
667            .now_or_never()
668            .unwrap()
669            .expect_err("not an ecs/eks container address");
670        assert!(matches!(
671            err,
672            InvalidFullUriError {
673                kind: InvalidFullUriErrorKind::DisallowedIP
674            }
675        ));
676    }
677
678    #[test]
679    fn all_addrs_local() {
680        let dns = Some(
681            TestDns::with_fallback(vec![
682                "127.0.0.1".parse().unwrap(),
683                "127.0.0.2".parse().unwrap(),
684                "169.254.170.23".parse().unwrap(),
685                "fd00:ec2::23".parse().unwrap(),
686            ])
687            .into_shared(),
688        );
689        let resp = validate_full_uri("http://localhost:8888", dns)
690            .now_or_never()
691            .unwrap();
692        assert!(resp.is_ok(), "Should be valid: {:?}", resp);
693    }
694
695    #[test]
696    fn all_addrs_not_local() {
697        let dns = Some(
698            TestDns::with_fallback(vec![
699                "127.0.0.1".parse().unwrap(),
700                "192.168.0.1".parse().unwrap(),
701            ])
702            .into_shared(),
703        );
704        let resp = validate_full_uri("http://localhost:8888", dns)
705            .now_or_never()
706            .unwrap();
707        assert!(
708            matches!(
709                resp,
710                Err(InvalidFullUriError {
711                    kind: InvalidFullUriErrorKind::DisallowedIP
712                })
713            ),
714            "Should be invalid: {:?}",
715            resp
716        );
717    }
718
719    fn creds_request(uri: &str, auth: Option<&str>) -> http::Request<SdkBody> {
720        let mut builder = http::Request::builder();
721        if let Some(auth) = auth {
722            builder = builder.header(AUTHORIZATION, auth);
723        }
724        builder.uri(uri).body(SdkBody::empty()).unwrap()
725    }
726
    /// Canned 200 response whose JSON body carries the values `assert_correct` checks for.
    fn ok_creds_response() -> http::Response<SdkBody> {
        http::Response::builder()
            .status(200)
            .body(SdkBody::from(
                r#" {
                       "AccessKeyId" : "AKID",
                       "SecretAccessKey" : "SECRET",
                       "Token" : "TOKEN....=",
                       "AccountId" : "AID",
                       "Expiration" : "2009-02-13T23:31:30Z"
                     }"#,
            ))
            .unwrap()
    }
741
742    #[track_caller]
743    fn assert_correct(creds: Credentials) {
744        assert_eq!(creds.access_key_id(), "AKID");
745        assert_eq!(creds.secret_access_key(), "SECRET");
746        assert_eq!(creds.account_id().unwrap().as_str(), "AID");
747        assert_eq!(creds.session_token().unwrap(), "TOKEN....=");
748        assert_eq!(
749            creds.expiry().unwrap(),
750            UNIX_EPOCH + Duration::from_secs(1234567890)
751        );
752    }
753
754    #[tokio::test]
755    async fn load_valid_creds_auth() {
756        let env = Env::from_slice(&[
757            ("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials"),
758            ("AWS_CONTAINER_AUTHORIZATION_TOKEN", "Basic password"),
759        ]);
760        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
761            creds_request("http://169.254.170.2/credentials", Some("Basic password")),
762            ok_creds_response(),
763        )]);
764        let provider = provider(env, Fs::default(), http_client.clone());
765        let creds = provider
766            .provide_credentials()
767            .await
768            .expect("valid credentials");
769        assert_correct(creds);
770        http_client.assert_requests_match(&[]);
771    }
772
773    #[tokio::test]
774    async fn load_valid_creds_auth_file() {
775        let env = Env::from_slice(&[
776            (
777                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
778                "http://169.254.170.23/v1/credentials",
779            ),
780            (
781                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
782                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
783            ),
784        ]);
785        let fs = Fs::from_raw_map(HashMap::from([(
786            OsString::from(
787                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
788            ),
789            "Basic password".into(),
790        )]));
791
792        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
793            creds_request(
794                "http://169.254.170.23/v1/credentials",
795                Some("Basic password"),
796            ),
797            ok_creds_response(),
798        )]);
799        let provider = provider(env, fs, http_client.clone());
800        let creds = provider
801            .provide_credentials()
802            .await
803            .expect("valid credentials");
804        assert_correct(creds);
805        http_client.assert_requests_match(&[]);
806    }
807
808    #[tokio::test]
809    #[traced_test]
810    async fn invalid_auth_token_env_does_not_log_value() {
811        let env = Env::from_slice(&[
812            ("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials"),
813            (
814                "AWS_CONTAINER_AUTHORIZATION_TOKEN",
815                "SECRET-MARKER-DO-NOT-LOG-abc123\n",
816            ),
817        ]);
818        let provider = provider(env, Fs::default(), no_traffic_client());
819        let err = provider
820            .provide_credentials()
821            .await
822            .expect_err("token with trailing newline should fail");
823        assert!(
824            matches!(err, CredentialsError::InvalidConfiguration { .. }),
825            "expected InvalidConfiguration, got: {:?}",
826            err
827        );
828        let error_display = format!("{}", DisplayErrorContext(&err));
829        assert!(
830            !error_display.contains("SECRET-MARKER"),
831            "error display must not contain the raw token value, got: {error_display}"
832        );
833        assert!(
834            !logs_contain("SECRET-MARKER"),
835            "logs must not contain the raw token value"
836        );
837    }
838
839    #[tokio::test]
840    #[traced_test]
841    async fn invalid_auth_token_file_does_not_log_value() {
842        let env = Env::from_slice(&[
843            (
844                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
845                "http://169.254.170.23/v1/credentials",
846            ),
847            (
848                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
849                "/eks-pod-identity-token",
850            ),
851        ]);
852        let fs = Fs::from_raw_map(HashMap::from([(
853            OsString::from("/eks-pod-identity-token"),
854            "SECRET-MARKER-DO-NOT-LOG-abc123\n".into(),
855        )]));
856        let provider = provider(env, fs, no_traffic_client());
857        let err = provider
858            .provide_credentials()
859            .await
860            .expect_err("token with trailing newline should fail");
861        assert!(
862            matches!(err, CredentialsError::InvalidConfiguration { .. }),
863            "expected InvalidConfiguration, got: {:?}",
864            err
865        );
866        let error_display = format!("{}", DisplayErrorContext(&err));
867        assert!(
868            !error_display.contains("SECRET-MARKER"),
869            "error display must not contain the raw token value, got: {error_display}"
870        );
871        assert!(
872            !logs_contain("SECRET-MARKER"),
873            "logs must not contain the raw token value"
874        );
875    }
876
877    #[tokio::test]
878    async fn auth_file_precedence_over_env() {
879        let env = Env::from_slice(&[
880            (
881                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
882                "http://169.254.170.23/v1/credentials",
883            ),
884            (
885                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
886                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
887            ),
888            ("AWS_CONTAINER_AUTHORIZATION_TOKEN", "unused"),
889        ]);
890        let fs = Fs::from_raw_map(HashMap::from([(
891            OsString::from(
892                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
893            ),
894            "Basic password".into(),
895        )]));
896
897        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
898            creds_request(
899                "http://169.254.170.23/v1/credentials",
900                Some("Basic password"),
901            ),
902            ok_creds_response(),
903        )]);
904        let provider = provider(env, fs, http_client.clone());
905        let creds = provider
906            .provide_credentials()
907            .await
908            .expect("valid credentials");
909        assert_correct(creds);
910        http_client.assert_requests_match(&[]);
911    }
912
913    #[tokio::test]
914    async fn query_params_should_be_included_in_credentials_http_request() {
915        let env = Env::from_slice(&[
916            (
917                "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
918                "/my-credentials/?applicationName=test2024",
919            ),
920            (
921                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
922                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
923            ),
924            ("AWS_CONTAINER_AUTHORIZATION_TOKEN", "unused"),
925        ]);
926        let fs = Fs::from_raw_map(HashMap::from([(
927            OsString::from(
928                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
929            ),
930            "Basic password".into(),
931        )]));
932
933        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
934            creds_request(
935                "http://169.254.170.2/my-credentials/?applicationName=test2024",
936                Some("Basic password"),
937            ),
938            ok_creds_response(),
939        )]);
940        let provider = provider(env, fs, http_client.clone());
941        let creds = provider
942            .provide_credentials()
943            .await
944            .expect("valid credentials");
945        assert_correct(creds);
946        http_client.assert_requests_match(&[]);
947    }
948
949    #[tokio::test]
950    async fn fs_missing_file() {
951        let env = Env::from_slice(&[
952            (
953                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
954                "http://169.254.170.23/v1/credentials",
955            ),
956            (
957                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
958                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
959            ),
960        ]);
961        let fs = Fs::from_raw_map(HashMap::new());
962
963        let provider = provider(env, fs, no_traffic_client());
964        let err = provider.credentials().await.expect_err("no JWT token file");
965        match err {
966            CredentialsError::ProviderError { .. } => { /* ok */ }
967            _ => panic!("incorrect error variant"),
968        }
969    }
970
971    #[tokio::test]
972    async fn retry_5xx() {
973        let env = Env::from_slice(&[("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials")]);
974        let http_client = StaticReplayClient::new(vec![
975            ReplayEvent::new(
976                creds_request("http://169.254.170.2/credentials", None),
977                http::Response::builder()
978                    .status(500)
979                    .body(SdkBody::empty())
980                    .unwrap(),
981            ),
982            ReplayEvent::new(
983                creds_request("http://169.254.170.2/credentials", None),
984                ok_creds_response(),
985            ),
986        ]);
987        tokio::time::pause();
988        let provider = provider(env, Fs::default(), http_client.clone());
989        let creds = provider
990            .provide_credentials()
991            .await
992            .expect("valid credentials");
993        assert_correct(creds);
994    }
995
996    #[tokio::test]
997    async fn load_valid_creds_no_auth() {
998        let env = Env::from_slice(&[("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials")]);
999        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
1000            creds_request("http://169.254.170.2/credentials", None),
1001            ok_creds_response(),
1002        )]);
1003        let provider = provider(env, Fs::default(), http_client.clone());
1004        let creds = provider
1005            .provide_credentials()
1006            .await
1007            .expect("valid credentials");
1008        assert_correct(creds);
1009        http_client.assert_requests_match(&[]);
1010    }
1011
1012    // ignored by default because it relies on actual DNS resolution
1013    #[allow(unused_attributes)]
1014    #[tokio::test]
1015    #[traced_test]
1016    #[ignore]
1017    async fn real_dns_lookup() {
1018        let dns = Some(
1019            default_dns()
1020                .expect("feature must be enabled")
1021                .into_shared(),
1022        );
1023        let err = validate_full_uri("http://www.amazon.com/creds", dns.clone())
1024            .await
1025            .expect_err("not a valid IP");
1026        assert!(
1027            matches!(
1028                err,
1029                InvalidFullUriError {
1030                    kind: InvalidFullUriErrorKind::DisallowedIP
1031                }
1032            ),
1033            "{:?}",
1034            err
1035        );
1036        assert!(logs_contain("Address does not resolve to an allowed IP"));
1037        validate_full_uri("http://localhost:8888/creds", dns.clone())
1038            .await
1039            .expect("localhost is the loopback interface");
1040        validate_full_uri("http://169.254.170.2.backname.io:8888/creds", dns.clone())
1041            .await
1042            .expect("169.254.170.2.backname.io is the ecs container address");
1043        validate_full_uri("http://169.254.170.23.backname.io:8888/creds", dns.clone())
1044            .await
1045            .expect("169.254.170.23.backname.io is the eks pod identity address");
1046        validate_full_uri("http://fd00-ec2--23.backname.io:8888/creds", dns)
1047            .await
1048            .expect("fd00-ec2--23.backname.io is the eks pod identity address");
1049    }
1050
/// DNS stub backed by canned data: `addrs` holds per-hostname answers and `fallback`
/// holds the addresses used for any hostname without an entry in `addrs`.
#[derive(Clone, Debug)]
struct TestDns {
    addrs: HashMap<String, Vec<IpAddr>>,
    fallback: Vec<IpAddr>,
}

impl Default for TestDns {
    /// Maps `localhost` to two loopback addresses and every other hostname to a single
    /// non-loopback address.
    fn default() -> Self {
        let loopback: Vec<IpAddr> = ["127.0.0.1", "127.0.0.2"]
            .iter()
            .map(|ip| ip.parse().unwrap())
            .collect();
        Self {
            addrs: HashMap::from([(String::from("localhost"), loopback)]),
            // non-loopback address
            fallback: vec!["72.21.210.29".parse().unwrap()],
        }
    }
}

impl TestDns {
    /// Builds a stub with no per-hostname entries, so `fallback` is the only data it holds.
    fn with_fallback(fallback: Vec<IpAddr>) -> Self {
        Self {
            addrs: HashMap::new(),
            fallback,
        }
    }
}
1082
1083    impl ResolveDns for TestDns {
1084        fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a> {
1085            DnsFuture::ready(Ok(self.addrs.get(name).unwrap_or(&self.fallback).clone()))
1086        }
1087    }
1088
1089    #[derive(Debug)]
1090    struct NeverDns;
1091    impl ResolveDns for NeverDns {
1092        fn resolve_dns<'a>(&'a self, _name: &'a str) -> DnsFuture<'a> {
1093            DnsFuture::new(async {
1094                Never::new().await;
1095                unreachable!()
1096            })
1097        }
1098    }
1099}