use std::borrow::Cow;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt::{Display, Formatter};
use std::str::FromStr;
use std::time::Duration;
use aws_http::user_agent::{ApiMetadata, AwsUserAgent, UserAgentStage};
use aws_smithy_client::{erase::DynConnector, timeout, SdkSuccess};
use aws_smithy_client::{retry, SdkError};
use aws_smithy_http::body::SdkBody;
use aws_smithy_http::endpoint::Endpoint;
use aws_smithy_http::operation;
use aws_smithy_http::operation::{Metadata, Operation};
use aws_smithy_http::response::ParseStrictResponse;
use aws_smithy_http::retry::ClassifyResponse;
use aws_smithy_http_tower::map_request::{
AsyncMapRequestLayer, AsyncMapRequestService, MapRequestLayer, MapRequestService,
};
use aws_smithy_types::retry::{ErrorKind, RetryKind};
use aws_smithy_types::timeout::TimeoutConfig;
use aws_types::os_shim_internal::{Env, Fs};
use bytes::Bytes;
use http::uri::InvalidUri;
use http::{Response, Uri};
use crate::connector::expect_connector;
use crate::imds::client::token::TokenMiddleware;
use crate::profile::ProfileParseError;
use crate::provider_config::{HttpSettings, ProviderConfig};
use crate::{profile, PKG_VERSION};
use tokio::sync::OnceCell;
mod token;
/// Default lifetime requested for IMDS session tokens: six hours (21,600 seconds).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(6 * 60 * 60);
/// Default value handed to `retry::Config::with_max_attempts` when the builder sets none.
const DEFAULT_ATTEMPTS: u32 = 4;
/// Default connect timeout given to the HTTP connector.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1);
/// Default read timeout given to the HTTP connector.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(1);
/// Builds the user agent inserted into every IMDS request's properties.
///
/// The agent is derived from the real process environment and identifies the `imds` API at
/// this crate's `PKG_VERSION`.
fn user_agent() -> AwsUserAgent {
    let process_env = Env::real();
    let api_metadata = ApiMetadata::new("imds", PKG_VERSION);
    AwsUserAgent::new_from_environment(process_env, api_metadata)
}
/// Client for the EC2 Instance Metadata Service (IMDS).
///
/// Create one with [`Client::builder`]. Request paths passed to [`Client::get`] are resolved
/// against `endpoint` and sent through `inner`. The middleware stack of `inner` inserts an
/// IMDS session token and a user agent into each request.
#[derive(Debug)]
pub struct Client {
    /// Base endpoint that request paths are resolved against.
    endpoint: Endpoint,
    /// Underlying smithy client wired with the IMDS middleware, retry and timeout settings.
    inner: aws_smithy_client::Client<DynConnector, ImdsMiddleware>,
}
/// An IMDS [`Client`] that is only built when it is first needed.
///
/// The first call to `LazyClient::client` builds a client from a clone of `builder`. The
/// outcome, success or [`BuildError`], is then stored in `client` and reused by later calls.
#[derive(Debug)]
pub(super) struct LazyClient {
    /// Cached build result, filled on first access.
    client: OnceCell<Result<Client, BuildError>>,
    /// Configuration used to build the client on first access.
    builder: Builder,
}
impl LazyClient {
pub fn from_ready_client(client: Client) -> Self {
Self {
client: OnceCell::from(Ok(client)),
builder: Builder::default(),
}
}
pub(super) async fn client(&self) -> Result<&Client, &BuildError> {
let builder = &self.builder;
self.client
.get_or_init(|| async {
let client = builder.clone().build().await;
if let Err(err) = &client {
tracing::warn!(err = % err, "failed to create IMDS client")
}
client
})
.await
.as_ref()
}
}
impl Client {
pub fn builder() -> Builder {
Builder::default()
}
pub async fn get(&self, path: &str) -> Result<String, ImdsError> {
let operation = self.make_operation(path)?;
self.inner.call(operation).await.map_err(|err| match err {
SdkError::ConstructionFailure(err) => match err.downcast::<ImdsError>() {
Ok(token_failure) => *token_failure,
Err(other) => ImdsError::Unexpected(other),
},
SdkError::TimeoutError(err) => ImdsError::IoError(err),
SdkError::DispatchFailure(err) => ImdsError::IoError(err.into()),
SdkError::ResponseError { err, .. } => ImdsError::IoError(err),
SdkError::ServiceError {
err: InnerImdsError::BadStatus,
raw,
} => ImdsError::ErrorResponse {
response: raw.into_parts().0,
},
SdkError::ServiceError {
err: InnerImdsError::InvalidUtf8,
..
} => ImdsError::Unexpected("IMDS returned invalid UTF-8".into()),
})
}
fn make_operation(
&self,
path: &str,
) -> Result<Operation<ImdsGetResponseHandler, ImdsErrorPolicy>, ImdsError> {
let mut base_uri: Uri = path.parse().map_err(|_| ImdsError::InvalidPath)?;
self.endpoint.set_endpoint(&mut base_uri, None);
let request = http::Request::builder()
.uri(base_uri)
.body(SdkBody::empty())
.expect("valid request");
let mut request = operation::Request::new(request);
request.properties_mut().insert(user_agent());
Ok(Operation::new(request, ImdsGetResponseHandler)
.with_metadata(Metadata::new("get", "imds"))
.with_retry_policy(ImdsErrorPolicy))
}
}
/// Errors returned by [`Client::get`].
#[derive(Debug)]
#[non_exhaustive]
pub enum ImdsError {
    /// Loading the IMDS session token failed.
    FailedToLoadToken(SdkError<TokenError>),
    /// The path passed to [`Client::get`] could not be parsed as a URI.
    InvalidPath,
    /// IMDS answered with a non-success status code.
    #[non_exhaustive]
    ErrorResponse {
        /// The raw HTTP response returned by IMDS.
        response: http::Response<SdkBody>,
    },
    /// A timeout, dispatch failure or response error occurred while talking to IMDS.
    IoError(Box<dyn Error + Send + Sync + 'static>),
    /// Any other failure, e.g. a response body that is not valid UTF-8.
    Unexpected(Box<dyn Error + Send + Sync + 'static>),
}
/// Human-readable description of each [`ImdsError`] variant.
///
/// Variants that wrap another error append that error's `Display` output.
impl Display for ImdsError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            ImdsError::FailedToLoadToken(source) => {
                write!(f, "Failed to load session token: {}", source)
            }
            ImdsError::InvalidPath => {
                f.write_str("IMDS path was not a valid URI. Hint: Does it begin with `/`?")
            }
            ImdsError::ErrorResponse { response } => {
                let code = response.status().as_u16();
                write!(f, "Error response from IMDS (code: {}). {:?}", code, response)
            }
            ImdsError::IoError(source) => write!(
                f,
                "An IO error occurred communicating with IMDS: {}",
                source
            ),
            ImdsError::Unexpected(source) => write!(
                f,
                "An unexpected error occurred communicating with IMDS: {}",
                source
            ),
        }
    }
}
impl Error for ImdsError {
    /// Returns the underlying cause when this error wraps one.
    ///
    /// `FailedToLoadToken`, `IoError` and `Unexpected` each carry an inner error, and all three
    /// expose it here. Previously only `FailedToLoadToken` did, so the cause of I/O and
    /// unexpected failures was lost from the error chain. `InvalidPath` and `ErrorResponse`
    /// wrap no error. The match is exhaustive, so adding a variant forces a decision here.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            ImdsError::FailedToLoadToken(inner) => Some(inner),
            ImdsError::IoError(inner) | ImdsError::Unexpected(inner) => Some(&**inner),
            ImdsError::InvalidPath | ImdsError::ErrorResponse { .. } => None,
        }
    }
}
/// Middleware stack applied to every request sent by [`Client`].
#[derive(Clone, Debug)]
struct ImdsMiddleware {
    /// Token middleware that supplies IMDS session tokens to outgoing requests.
    token_loader: TokenMiddleware,
}
/// Wraps a service with the IMDS request mappers.
///
/// The async token mapper is the outermost layer, and the user-agent stage sits between it
/// and `inner`.
impl<S> tower::Layer<S> for ImdsMiddleware {
    type Service = AsyncMapRequestService<MapRequestService<S, UserAgentStage>, TokenMiddleware>;

    fn layer(&self, inner: S) -> Self::Service {
        let token_layer = AsyncMapRequestLayer::for_mapper(self.token_loader.clone());
        let with_user_agent = MapRequestLayer::for_mapper(UserAgentStage::new()).layer(inner);
        token_layer.layer(with_user_agent)
    }
}
/// Response parser for IMDS `GET` requests.
///
/// A success response yields its body as a `String`; see the `ParseStrictResponse` impl.
#[derive(Copy, Clone)]
struct ImdsGetResponseHandler;
/// Failures produced while parsing an IMDS `GET` response.
#[derive(Debug)]
enum InnerImdsError {
    /// The response status code was not a success code.
    BadStatus,
    /// The response body was not valid UTF-8.
    InvalidUtf8,
}

impl Display for InnerImdsError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            InnerImdsError::BadStatus => "failing status code returned from IMDS",
            InnerImdsError::InvalidUtf8 => "IMDS did not return valid UTF-8",
        };
        f.write_str(message)
    }
}

/// Leaf error: neither variant wraps another error, so the default `source` (`None`) applies.
impl Error for InnerImdsError {}
impl ParseStrictResponse for ImdsGetResponseHandler {
    type Output = Result<String, InnerImdsError>;

    /// Returns the body as a `String` for a success status.
    ///
    /// Returns `BadStatus` for any other status and `InvalidUtf8` when the body is not UTF-8.
    fn parse(&self, response: &Response<Bytes>) -> Self::Output {
        if !response.status().is_success() {
            return Err(InnerImdsError::BadStatus);
        }
        match std::str::from_utf8(response.body().as_ref()) {
            Ok(text) => Ok(text.to_owned()),
            Err(_) => Err(InnerImdsError::InvalidUtf8),
        }
    }
}
/// IP protocol family used to pick the default IMDS endpoint.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EndpointMode {
    /// Use the IPv4 endpoint.
    IpV4,
    /// Use the IPv6 endpoint.
    IpV6,
}

/// Error returned when a string is not a recognised [`EndpointMode`]; holds the rejected input.
#[derive(Debug, Clone)]
pub struct InvalidEndpointMode(String);

impl Display for InvalidEndpointMode {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let rejected = &self.0;
        write!(
            f,
            "`{}` is not a valid endpoint mode. Valid values are [`IPv4`, `IPv6`]",
            rejected
        )
    }
}

impl Error for InvalidEndpointMode {}

impl FromStr for EndpointMode {
    type Err = InvalidEndpointMode;

    /// Accepts `ipv4` or `ipv6` in any ASCII case; every other input is rejected.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        if value.eq_ignore_ascii_case("ipv4") {
            Ok(EndpointMode::IpV4)
        } else if value.eq_ignore_ascii_case("ipv6") {
            Ok(EndpointMode::IpV6)
        } else {
            Err(InvalidEndpointMode(value.to_owned()))
        }
    }
}
impl EndpointMode {
    /// Returns the default IMDS endpoint URI for this mode.
    fn endpoint(&self) -> Uri {
        let address: &'static str = match self {
            EndpointMode::IpV4 => "http://169.254.169.254",
            EndpointMode::IpV6 => "http://[fd00:ec2::254]",
        };
        Uri::from_static(address)
    }
}
/// Builder for [`Client`].
///
/// Every field is optional; [`Builder::build`] substitutes a default for each unset value.
#[derive(Default, Debug, Clone)]
pub struct Builder {
    /// Maximum attempts per request (defaults to `DEFAULT_ATTEMPTS`).
    max_attempts: Option<u32>,
    /// Endpoint source; when `None`, the endpoint is resolved from the environment and profile.
    endpoint: Option<EndpointSource>,
    /// Endpoint mode override (IPv4 / IPv6).
    mode_override: Option<EndpointMode>,
    /// Session token TTL (defaults to `DEFAULT_TOKEN_TTL`).
    token_ttl: Option<Duration>,
    /// HTTP connect timeout (defaults to `DEFAULT_CONNECT_TIMEOUT`).
    connect_timeout: Option<Duration>,
    /// HTTP read timeout (defaults to `DEFAULT_READ_TIMEOUT`).
    read_timeout: Option<Duration>,
    /// Provider configuration supplying connector, env, fs and time source (defaults to
    /// `ProviderConfig::default()`).
    config: Option<ProviderConfig>,
}
/// Errors returned when building an IMDS [`Client`].
#[derive(Debug)]
pub enum BuildError {
    /// The endpoint mode from the environment or profile was neither IPv4 nor IPv6.
    InvalidEndpointMode(InvalidEndpointMode),
    /// The profile could not be loaded.
    InvalidProfile(ProfileParseError),
    /// The endpoint override from the environment or profile was not a valid URI.
    InvalidEndpointUri(InvalidUri),
}
/// Prefixes the wrapped error's message with a note that building the client failed.
impl Display for BuildError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("failed to build IMDS client: ")?;
        let cause: &dyn Display = match self {
            BuildError::InvalidEndpointMode(inner) => inner,
            BuildError::InvalidProfile(inner) => inner,
            BuildError::InvalidEndpointUri(inner) => inner,
        };
        write!(f, "{}", cause)
    }
}
impl Error for BuildError {
    /// Every variant wraps an error, and that inner error is always returned as the source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let inner: &(dyn Error + 'static) = match self {
            BuildError::InvalidEndpointMode(cause) => cause,
            BuildError::InvalidProfile(cause) => cause,
            BuildError::InvalidEndpointUri(cause) => cause,
        };
        Some(inner)
    }
}
impl Builder {
    /// Sets the maximum number of attempts per IMDS request (default: `DEFAULT_ATTEMPTS`).
    ///
    /// The same retry configuration is used for session-token requests.
    pub fn max_attempts(mut self, max_attempts: u32) -> Self {
        self.max_attempts = Some(max_attempts);
        self
    }

    /// Uses `provider_config` for the HTTP connector, environment, filesystem and time source.
    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
        self.config = Some(provider_config.clone());
        self
    }

    /// Sets an explicit endpoint.
    ///
    /// When set, endpoint configuration from the environment and profile is not consulted, and
    /// any [`Builder::endpoint_mode`] override is ignored with a warning.
    pub fn endpoint(mut self, endpoint: impl Into<Uri>) -> Self {
        self.endpoint = Some(EndpointSource::Explicit(endpoint.into()));
        self
    }

    /// Overrides the endpoint mode (IPv4 or IPv6) used to pick the default endpoint.
    ///
    /// Takes precedence over the mode from the environment or profile. It is ignored when an
    /// endpoint is set explicitly or through the environment or profile.
    pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
        self.mode_override = Some(mode);
        self
    }

    /// Sets the TTL requested for IMDS session tokens (default: `DEFAULT_TOKEN_TTL`).
    pub fn token_ttl(mut self, ttl: Duration) -> Self {
        self.token_ttl = Some(ttl);
        self
    }

    /// Sets the HTTP connect timeout (default: `DEFAULT_CONNECT_TIMEOUT`).
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = Some(timeout);
        self
    }

    /// Sets the HTTP read timeout (default: `DEFAULT_READ_TIMEOUT`).
    pub fn read_timeout(mut self, timeout: Duration) -> Self {
        self.read_timeout = Some(timeout);
        self
    }

    /// Returns a [`LazyClient`] that builds the client from this configuration on first use.
    pub(super) fn build_lazy(self) -> LazyClient {
        LazyClient {
            client: OnceCell::new(),
            builder: self,
        }
    }

    /// Builds the IMDS client.
    ///
    /// # Errors
    /// Returns a [`BuildError`] in three cases: the profile cannot be loaded, the endpoint from
    /// the environment or profile is not a valid URI, or the configured endpoint mode is invalid.
    pub async fn build(self) -> Result<Client, BuildError> {
        let config = self.config.unwrap_or_default();

        // Connect/read timeouts are passed to the HTTP connector itself.
        let http_timeouts = timeout::Settings::default()
            .with_connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
            .with_read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT));
        let connector = expect_connector(config.connector(&HttpSettings {
            timeout_settings: http_timeouts,
        }));

        let endpoint_source = self
            .endpoint
            .unwrap_or_else(|| EndpointSource::Env(config.env(), config.fs()));
        let endpoint = Endpoint::immutable(endpoint_source.endpoint(self.mode_override).await?);

        let retry_config = retry::Config::default()
            .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
        // Operation-level timeouts; kept distinct from the HTTP connector timeouts above.
        let operation_timeouts = TimeoutConfig::default();

        // Token requests share the connector, endpoint, retry and timeout settings.
        let token_loader = token::TokenMiddleware::new(
            connector.clone(),
            config.time_source(),
            endpoint.clone(),
            self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
            retry_config.clone(),
            operation_timeouts.clone(),
        );

        let inner = aws_smithy_client::Builder::new()
            .connector(connector)
            .middleware(ImdsMiddleware { token_loader })
            .build()
            .with_retry_config(retry_config)
            .with_timeout_config(operation_timeouts);

        Ok(Client { endpoint, inner })
    }
}
/// Environment variables consulted when resolving the IMDS endpoint.
mod env {
    /// Explicit IMDS endpoint URI; takes precedence over the profile and the endpoint mode.
    pub const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
    /// Endpoint mode (`IPv4` / `IPv6`, case-insensitive).
    pub const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
}
/// Profile keys consulted when the corresponding environment variable is not set.
mod profile_keys {
    /// Explicit IMDS endpoint URI.
    pub const ENDPOINT: &str = "ec2_metadata_service_endpoint";
    /// Endpoint mode (`IPv4` / `IPv6`, case-insensitive).
    pub const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
}
/// Where the IMDS endpoint comes from.
#[derive(Debug, Clone)]
enum EndpointSource {
    /// An endpoint set explicitly through [`Builder::endpoint`].
    Explicit(Uri),
    /// Resolve from environment variables and the profile, read through this `Env` and `Fs`.
    Env(Env, Fs),
}
impl EndpointSource {
    /// Resolves the IMDS endpoint URI.
    ///
    /// An explicit endpoint is returned as-is; if `mode_override` is also set, it is ignored
    /// and a warning is logged. Otherwise the first match below wins:
    /// 1. the `AWS_EC2_METADATA_SERVICE_ENDPOINT` environment variable,
    /// 2. the `ec2_metadata_service_endpoint` profile key,
    /// 3. the default URI for the endpoint mode. The mode is taken from `mode_override`, then
    ///    `AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE`, then `ec2_metadata_service_endpoint_mode`,
    ///    and falls back to IPv4.
    ///
    /// In the non-explicit case the profile is loaded before anything else is checked, so a
    /// profile load failure is an error even if the environment supplies the endpoint.
    async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
        let (env, fs) = match self {
            EndpointSource::Explicit(uri) => {
                if mode_override.is_some() {
                    tracing::warn!(endpoint = ?uri, mode = ?mode_override,
                        "Endpoint mode override was set in combination with an explicit endpoint. \
                        The mode override will be ignored.")
                }
                return Ok(uri.clone());
            }
            EndpointSource::Env(env, fs) => (env, fs),
        };

        let profile = profile::load(fs, env)
            .await
            .map_err(BuildError::InvalidProfile)?;

        // An explicit endpoint URI (environment first, then profile) beats any endpoint mode.
        let endpoint_override = match env.get(env::ENDPOINT) {
            Ok(from_env) => Some(Cow::Owned(from_env)),
            Err(_) => profile.get(profile_keys::ENDPOINT).map(Cow::Borrowed),
        };
        if let Some(raw_uri) = endpoint_override {
            return Uri::try_from(raw_uri.as_ref()).map_err(BuildError::InvalidEndpointUri);
        }

        let mode = match mode_override {
            Some(explicit) => explicit,
            None => {
                let configured = match env.get(env::ENDPOINT_MODE) {
                    Ok(from_env) => Some(Cow::Owned(from_env)),
                    Err(_) => profile.get(profile_keys::ENDPOINT_MODE).map(Cow::Borrowed),
                };
                match configured {
                    Some(raw_mode) => raw_mode
                        .parse::<EndpointMode>()
                        .map_err(BuildError::InvalidEndpointMode)?,
                    None => EndpointMode::IpV4,
                }
            }
        };
        Ok(mode.endpoint())
    }
}
/// Failures that can occur while obtaining an IMDS session token.
#[derive(Debug)]
pub enum TokenError {
    /// The token returned by IMDS was rejected as invalid.
    InvalidToken,
    /// The token response had no TTL header.
    NoTtl,
    /// The TTL returned with the token was invalid.
    InvalidTtl,
    /// The token request parameters were rejected; per the message, this indicates an SDK bug.
    InvalidParameters,
    /// The token request was forbidden.
    Forbidden,
}

impl Display for TokenError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            TokenError::InvalidToken => "Invalid Token",
            TokenError::NoTtl => "Token response did not contain a TTL header",
            TokenError::InvalidTtl => "The returned TTL was invalid",
            TokenError::InvalidParameters => {
                "Invalid request parameters. This indicates an SDK bug."
            }
            TokenError::Forbidden => {
                "Request forbidden: IMDS is disabled or the caller has insufficient permissions."
            }
        };
        f.write_str(description)
    }
}

/// Leaf error: no variant wraps another error, so the default `source` (`None`) applies.
impl Error for TokenError {}
/// Retry policy for IMDS operations: responses with a 5xx or 401 status are retried.
#[derive(Clone)]
struct ImdsErrorPolicy;
impl ImdsErrorPolicy {
fn classify(response: &operation::Response) -> RetryKind {
let status = response.http().status();
match status {
_ if status.is_server_error() => RetryKind::Error(ErrorKind::ServerError),
_ if status.as_u16() == 401 => RetryKind::Error(ErrorKind::ServerError),
_ => RetryKind::NotRetryable,
}
}
}
impl<T, E> ClassifyResponse<SdkSuccess<T>, SdkError<E>> for ImdsErrorPolicy {
    /// Decides whether an operation result should be retried.
    ///
    /// Only response and service errors carry a raw HTTP response, and their status code
    /// decides retryability. Successes and all other errors are never retried.
    fn classify(&self, response: Result<&SdkSuccess<T>, &SdkError<E>>) -> RetryKind {
        let raw = match response {
            Err(SdkError::ResponseError { raw, .. }) => raw,
            Err(SdkError::ServiceError { raw, .. }) => raw,
            _ => return RetryKind::NotRetryable,
        };
        ImdsErrorPolicy::classify(raw)
    }
}
/// Tests for the IMDS client, driven by scripted request/response pairs.
///
/// `TestConnection` replays its pairs in order, so the order of each fixture list is part of
/// what each test asserts.
#[cfg(test)]
pub(crate) mod test {
    use std::collections::HashMap;
    use std::error::Error;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    use aws_hyper::DynConnector;
    use aws_smithy_client::test_connection::{capture_request, TestConnection};
    use aws_smithy_http::body::SdkBody;
    use aws_types::os_shim_internal::{Env, Fs, ManualTimeSource, TimeSource};
    use http::Uri;
    use serde::Deserialize;
    use tracing_test::traced_test;

    use crate::imds::client::{Client, EndpointMode, ImdsError};
    use crate::provider_config::ProviderConfig;
    use http::header::USER_AGENT;

    /// Session token returned by the first mocked token response.
    const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
    /// Second token, used where a test expects a fresh token to be fetched.
    const TOKEN_B: &str = "alternatetoken==";

    /// Expected `PUT {base}/latest/api/token` request carrying the TTL header.
    pub(crate) fn token_request(base: &str, ttl: u32) -> http::Request<SdkBody> {
        http::Request::builder()
            .uri(format!("{}/latest/api/token", base))
            .header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
            .method("PUT")
            .body(SdkBody::empty())
            .unwrap()
    }

    /// `200` token response with the TTL header and `token` as its body.
    pub(crate) fn token_response(ttl: u32, token: &'static str) -> http::Response<&'static str> {
        http::Response::builder()
            .status(200)
            .header("X-aws-ec2-metadata-token-ttl-seconds", ttl)
            .body(token)
            .unwrap()
    }

    /// Expected `GET` to `path` (a full URI) carrying the session token header.
    pub(crate) fn imds_request(path: &'static str, token: &str) -> http::Request<SdkBody> {
        http::Request::builder()
            .uri(Uri::from_static(path))
            .method("GET")
            .header("x-aws-ec2-metadata-token", token)
            .body(SdkBody::empty())
            .unwrap()
    }

    /// `200` IMDS response with `body`.
    pub(crate) fn imds_response(body: &'static str) -> http::Response<&'static str> {
        http::Response::builder().status(200).body(body).unwrap()
    }

    /// Builds a client using only `conn` (no env/profile configuration), with default settings.
    async fn make_client<T>(conn: &TestConnection<T>) -> super::Client
    where
        SdkBody: From<T>,
        T: Send + 'static,
    {
        super::Client::builder()
            .configure(
                &ProviderConfig::no_configuration()
                    .with_http_connector(DynConnector::new(conn.clone())),
            )
            .build()
            .await
            .expect("valid client")
    }

    /// Two requests share one token: only a single token request is scripted.
    #[tokio::test]
    async fn client_caches_token() {
        let connection = TestConnection::new(vec![
            (
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A),
            ),
            (
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output"#),
            ),
            (
                imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
                imds_response("output2"),
            ),
        ]);
        let client = make_client(&connection).await;
        let metadata = client.get("/latest/metadata").await.expect("failed");
        assert_eq!(metadata, "test-imds-output");
        let metadata = client.get("/latest/metadata2").await.expect("failed");
        assert_eq!(metadata, "output2");
        connection.assert_requests_match(&[]);
    }

    /// After advancing the clock by the full TTL (600s), a new token is requested.
    #[tokio::test]
    async fn token_can_expire() {
        let connection = TestConnection::new(vec![
            (
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_A),
            ),
            (
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output1"#),
            ),
            (
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_B),
            ),
            (
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
                imds_response(r#"test-imds-output2"#),
            ),
        ]);
        let mut time_source = ManualTimeSource::new(UNIX_EPOCH);
        let client = super::Client::builder()
            .configure(
                &ProviderConfig::no_configuration()
                    .with_http_connector(DynConnector::new(connection.clone()))
                    .with_time_source(TimeSource::manual(&time_source)),
            )
            .endpoint_mode(EndpointMode::IpV6)
            .token_ttl(Duration::from_secs(600))
            .build()
            .await
            .expect("valid client");
        let resp1 = client.get("/latest/metadata").await.expect("success");
        time_source.advance(Duration::from_secs(600));
        let resp2 = client.get("/latest/metadata").await.expect("success");
        connection.assert_requests_match(&[]);
        assert_eq!(resp1, "test-imds-output1");
        assert_eq!(resp2, "test-imds-output2");
    }

    /// The token is reused at 400s, but a new one is fetched at 550s, before the 600s TTL is
    /// reached. The exact refresh buffer lives in the token module (not visible here).
    #[tokio::test]
    async fn token_refresh_buffer() {
        let connection = TestConnection::new(vec![
            (
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_A),
            ),
            (
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output1"#),
            ),
            (
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output2"#),
            ),
            (
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_B),
            ),
            (
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
                imds_response(r#"test-imds-output3"#),
            ),
        ]);
        let mut time_source = ManualTimeSource::new(UNIX_EPOCH);
        let client = super::Client::builder()
            .configure(
                &ProviderConfig::no_configuration()
                    .with_http_connector(DynConnector::new(connection.clone()))
                    .with_time_source(TimeSource::manual(&time_source)),
            )
            .endpoint_mode(EndpointMode::IpV6)
            .token_ttl(Duration::from_secs(600))
            .build()
            .await
            .expect("valid client");
        let resp1 = client.get("/latest/metadata").await.expect("success");
        time_source.advance(Duration::from_secs(400));
        let resp2 = client.get("/latest/metadata").await.expect("success");
        time_source.advance(Duration::from_secs(150));
        let resp3 = client.get("/latest/metadata").await.expect("success");
        connection.assert_requests_match(&[]);
        assert_eq!(resp1, "test-imds-output1");
        assert_eq!(resp2, "test-imds-output2");
        assert_eq!(resp3, "test-imds-output3");
    }

    /// A 500 from the metadata request is retried with the same token, and every request
    /// carries a `User-Agent` header.
    #[tokio::test]
    #[traced_test]
    async fn retry_500() {
        let connection = TestConnection::new(vec![
            (
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A),
            ),
            (
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                http::Response::builder().status(500).body("").unwrap(),
            ),
            (
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                imds_response("ok"),
            ),
        ]);
        let client = make_client(&connection).await;
        assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok");
        connection.assert_requests_match(&[]);

        // Both the token request and the IMDS requests must carry the user agent.
        for request in connection.requests().iter() {
            assert!(request.actual.headers().get(USER_AGENT).is_some());
        }
    }

    /// A 500 from the token endpoint is retried before the metadata request is sent.
    #[tokio::test]
    #[traced_test]
    async fn retry_token_failure() {
        let connection = TestConnection::new(vec![
            (
                token_request("http://169.254.169.254", 21600),
                http::Response::builder().status(500).body("").unwrap(),
            ),
            (
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A),
            ),
            (
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                imds_response("ok"),
            ),
        ]);
        let client = make_client(&connection).await;
        assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok");
        connection.assert_requests_match(&[]);
    }

    /// A 403 from the token endpoint is not retried and surfaces as a "forbidden" error.
    #[tokio::test]
    #[traced_test]
    async fn no_403_retry() {
        let connection = TestConnection::new(vec![(
            token_request("http://169.254.169.254", 21600),
            http::Response::builder().status(403).body("").unwrap(),
        )]);
        let client = make_client(&connection).await;
        let err = client.get("/latest/metadata").await.expect_err("no token");
        assert!(format!("{}", err).contains("forbidden"), "{}", err);
        connection.assert_requests_match(&[]);
    }

    /// A token body of raw bytes `[1, 0]` is rejected with "Invalid Token".
    #[tokio::test]
    async fn invalid_token() {
        let connection = TestConnection::new(vec![(
            token_request("http://169.254.169.254", 21600),
            token_response(21600, "replaced").map(|_| vec![1, 0]),
        )]);
        let client = make_client(&connection).await;
        let err = client.get("/latest/metadata").await.expect_err("no token");
        assert!(format!("{}", err).contains("Invalid Token"), "{}", err);
        connection.assert_requests_match(&[]);
    }

    /// A metadata body that is not UTF-8 produces an "invalid UTF-8" error.
    #[tokio::test]
    async fn non_utf8_response() {
        let connection = TestConnection::new(vec![
            (
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A).map(SdkBody::from),
            ),
            (
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                http::Response::builder()
                    .status(200)
                    .body(SdkBody::from(vec![0xA0 as u8, 0xA1 as u8]))
                    .unwrap(),
            ),
        ]);
        let client = make_client(&connection).await;
        let err = client.get("/latest/metadata").await.expect_err("no token");
        assert!(format!("{}", err).contains("invalid UTF-8"), "{}", err);
        connection.assert_requests_match(&[]);
    }

    /// Against an address the test expects to be unreachable, the request fails with a
    /// "timed out" token error after 1-2 seconds, which checks the default 1s connect timeout.
    /// Ignored by default because it uses the real network.
    #[ignore]
    #[tokio::test]
    async fn one_second_connect_timeout() {
        let client = Client::builder()
            .endpoint(Uri::from_static("http://240.0.0.0"))
            .build()
            .await
            .expect("valid client");
        let now = SystemTime::now();
        let resp = client
            .get("/latest/metadata")
            .await
            .expect_err("240.0.0.0 will never resolve");
        assert!(now.elapsed().unwrap() > Duration::from_secs(1));
        assert!(now.elapsed().unwrap() < Duration::from_secs(2));
        match resp {
            ImdsError::FailedToLoadToken(err) if format!("{}", err).contains("timed out") => {}
            other => panic!(
                "wrong error, expected construction failure with TimedOutError inside: {}",
                other
            ),
        }
    }

    /// One endpoint-resolution case from `test-data/imds-config/imds-tests.json`.
    #[derive(Debug, Deserialize)]
    struct ImdsConfigTest {
        /// Environment variables visible to the client.
        env: HashMap<String, String>,
        /// Files visible to the client (path -> contents).
        fs: HashMap<String, String>,
        /// Optional explicit endpoint passed to `Builder::endpoint`.
        endpoint_override: Option<String>,
        /// Optional endpoint mode passed to `Builder::endpoint_mode`.
        mode_override: Option<String>,
        /// `Ok(uri)`: the expected request URI; `Err(s)`: a substring of the expected build error.
        result: Result<String, String>,
        /// Description of the case, included in failure messages.
        docs: String,
    }

    /// Runs every case from the shared IMDS configuration test file.
    #[tokio::test]
    async fn config_tests() -> Result<(), Box<dyn Error>> {
        let test_cases = std::fs::read_to_string("test-data/imds-config/imds-tests.json")?;
        #[derive(Deserialize)]
        struct TestCases {
            tests: Vec<ImdsConfigTest>,
        }

        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
        let test_cases = test_cases.tests;
        for test in test_cases {
            check(test).await;
        }
        Ok(())
    }

    /// Builds a client for `test_case` and checks either the build error or the request URI.
    async fn check(test_case: ImdsConfigTest) {
        let (server, watcher) = capture_request(None);
        let provider_config = ProviderConfig::no_configuration()
            .with_env(Env::from(test_case.env))
            .with_fs(Fs::from_map(test_case.fs))
            .with_http_connector(DynConnector::new(server));
        let mut imds_client = Client::builder().configure(&provider_config);
        if let Some(endpoint_override) = test_case.endpoint_override {
            imds_client = imds_client.endpoint(endpoint_override.parse::<Uri>().unwrap());
        }

        if let Some(mode_override) = test_case.mode_override {
            imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
        }

        let imds_client = imds_client.build().await;
        let (uri, imds_client) = match (&test_case.result, imds_client) {
            (Ok(uri), Ok(client)) => (uri, client),
            (Err(test), Ok(_client)) => panic!(
                "test should fail: {} but a valid client was made. {}",
                test, test_case.docs
            ),
            (Err(substr), Err(err)) => {
                assert!(
                    format!("{}", err).contains(substr),
                    "`{}` did not contain `{}`",
                    err,
                    substr
                );
                return;
            }
            (Ok(_uri), Err(e)) => panic!(
                "a valid client should be made but: {}. {}",
                e, test_case.docs
            ),
        };
        // The response is irrelevant; only the URI of the first captured request
        // (presumably the token request) is checked.
        let _ = imds_client.get("/hello").await;
        assert_eq!(&watcher.expect_request().uri().to_string(), uri);
    }
}