use aws_sdk_sts::error::AssumeRoleErrorKind;
use aws_sdk_sts::operation::AssumeRole;
use aws_types::credentials::{
self, future, CredentialsError, ProvideCredentials, SharedCredentialsProvider,
};
use aws_types::region::Region;
use crate::provider_config::HttpSettings;
use aws_smithy_async::rt::sleep::default_async_sleep;
use tracing::Instrument;
/// Credentials provider that obtains credentials by calling STS `AssumeRole`.
///
/// Construct one with [`AssumeRoleProvider::builder`].
#[derive(Debug)]
pub struct AssumeRoleProvider {
/// Client used to dispatch the STS `AssumeRole` request.
sts: aws_hyper::StandardClient,
/// STS configuration: the base credentials provider used to call STS, plus the region.
conf: aws_sdk_sts::Config,
/// Pre-built `AssumeRole` input; cloned and turned into an operation on every credentials request.
op: aws_sdk_sts::input::AssumeRoleInput,
}
impl AssumeRoleProvider {
    /// Starts building an [`AssumeRoleProvider`] that will assume the role `role` (a role ARN).
    ///
    /// This is a convenience for [`AssumeRoleProviderBuilder::new`].
    pub fn builder(role: impl Into<String>) -> AssumeRoleProviderBuilder {
        let role_arn: String = role.into();
        AssumeRoleProviderBuilder::new(role_arn)
    }
}
/// Builder for [`AssumeRoleProvider`].
///
/// Every setting except the role ARN is optional; unset values are defaulted in `build`.
pub struct AssumeRoleProviderBuilder {
/// ARN of the role to assume.
role_arn: String,
/// Optional external ID sent with the `AssumeRole` request.
external_id: Option<String>,
/// Optional role session name; a default is generated in `build` when unset.
session_name: Option<String>,
/// Region passed to the STS client configuration.
region: Option<Region>,
/// Optional HTTP connector override; the crate's default connector is used when unset.
connection: Option<aws_smithy_client::erase::DynConnector>,
}
impl AssumeRoleProviderBuilder {
    /// Creates a builder that will assume the role identified by `role` (a role ARN).
    ///
    /// The external ID, session name, region and connector all start unset.
    pub fn new(role: impl Into<String>) -> Self {
        Self {
            role_arn: role.into(),
            external_id: None,
            session_name: None,
            region: None,
            connection: None,
        }
    }

    /// Sets the external ID sent with the `AssumeRole` request.
    pub fn external_id(mut self, id: impl Into<String>) -> Self {
        self.external_id = Some(id.into());
        self
    }

    /// Sets the role session name sent with the `AssumeRole` request.
    ///
    /// When unset, [`build`](Self::build) generates one with
    /// `default_session_name("assume-role-provider")`.
    pub fn session_name(mut self, name: impl Into<String>) -> Self {
        self.session_name = Some(name.into());
        self
    }

    /// Sets the region used by the STS client.
    pub fn region(mut self, region: Region) -> Self {
        self.region = Some(region);
        self
    }

    /// Overrides the HTTP connector used to reach STS.
    ///
    /// When unset, [`build`](Self::build) creates the crate's default connector from
    /// `HttpSettings::default()` and the default async sleep implementation.
    pub fn connection(mut self, conn: impl aws_smithy_client::bounds::SmithyConnector) -> Self {
        self.connection = Some(aws_smithy_client::erase::DynConnector::new(conn));
        self
    }

    /// Builds the provider.
    ///
    /// `provider` supplies the base credentials that the STS client uses when calling
    /// `AssumeRole`.
    ///
    /// # Panics
    ///
    /// - NOTE(review): presumably panics when no connector was supplied and no default
    ///   connector is available (via `crate::connector::expect_connector`). Confirm this.
    /// - Panics if the `AssumeRole` input fails to build. This is treated as an invariant,
    ///   because the role ARN and session name are always provided here.
    pub fn build(self, provider: impl Into<SharedCredentialsProvider>) -> AssumeRoleProvider {
        // `self` is owned, so the region is moved into the config rather than cloned.
        let config = aws_sdk_sts::Config::builder()
            .credentials_provider(provider.into())
            .region(self.region)
            .build();
        let conn = self.connection.unwrap_or_else(|| {
            crate::connector::expect_connector(crate::connector::default_connector(
                &HttpSettings::default(),
                default_async_sleep(),
            ))
        });
        let client = aws_hyper::Client::new(conn);
        let session_name = self
            .session_name
            .unwrap_or_else(|| super::util::default_session_name("assume-role-provider"));
        let operation = AssumeRole::builder()
            .set_role_arn(Some(self.role_arn))
            .set_external_id(self.external_id)
            .set_role_session_name(Some(session_name))
            .build()
            .expect("operation is valid");
        AssumeRoleProvider {
            sts: client,
            conf: config,
            op: operation,
        }
    }
}
impl AssumeRoleProvider {
    /// Calls STS `AssumeRole` with the stored input and converts the response into credentials.
    ///
    /// # Errors
    ///
    /// - An error built with `CredentialsError::invalid_configuration` when STS responds with
    ///   `RegionDisabledException` or `MalformedPolicyDocumentException`.
    /// - An error built with `CredentialsError::provider_error` in three cases:
    ///   - STS returns any other service error.
    ///   - The request fails to dispatch.
    ///   - The request cannot be constructed from the stored input.
    /// - Any error produced by `super::util::into_credentials` while converting the response.
    #[tracing::instrument(
        name = "assume_role",
        level = "info",
        skip(self),
        fields(op = ?self.op)
    )]
    async fn credentials(&self) -> credentials::Result {
        tracing::info!("assuming role");
        tracing::debug!("retrieving assumed credentials");
        // Building the operation is fallible. Surface a failure as a credentials error
        // instead of panicking inside a credentials provider.
        let op = self
            .op
            .clone()
            .make_operation(&self.conf)
            .await
            .map_err(CredentialsError::provider_error)?;
        let assumed = self.sts.call(op).in_current_span().await;
        match assumed {
            Ok(assumed) => {
                tracing::debug!(
                    access_key_id = ?assumed.credentials.as_ref().map(|c| &c.access_key_id),
                    "obtained assumed credentials"
                );
                super::util::into_credentials(assumed.credentials, "AssumeRoleProvider")
            }
            Err(aws_hyper::SdkError::ServiceError { err, raw }) => {
                // These service errors point at how the provider is set up (the region or
                // the policy document), so they are reported as configuration errors.
                let is_config_error = matches!(
                    err.kind,
                    AssumeRoleErrorKind::RegionDisabledException(_)
                        | AssumeRoleErrorKind::MalformedPolicyDocumentException(_)
                );
                if is_config_error {
                    return Err(CredentialsError::invalid_configuration(
                        aws_hyper::SdkError::ServiceError { err, raw },
                    ));
                }
                tracing::warn!(error = ?err.message(), "sts refused to grant assume role");
                Err(CredentialsError::provider_error(
                    aws_hyper::SdkError::ServiceError { err, raw },
                ))
            }
            Err(err) => Err(CredentialsError::provider_error(err)),
        }
    }
}
/// Exposes [`AssumeRoleProvider`] through the [`ProvideCredentials`] trait.
impl ProvideCredentials for AssumeRoleProvider {
    /// Returns a future that resolves credentials by calling STS `AssumeRole`.
    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials
    where
        Self: 'a,
    {
        let pending = self.credentials();
        future::ProvideCredentials::new(pending)
    }
}