use std::sync::Arc;
use chrono::{DateTime, Utc};
use reqwest::header::CONTENT_TYPE;
use serde::Deserialize;
use url::Url;
use crate::authn::factor::ZeroizedString;
/// STS Query API version sent as the `Version` form parameter on every request this module
/// builds.
const STS_API_VERSION: &str = "2011-06-15";
fn default_endpoint() -> Url {
Url::parse("https://sts.amazonaws.com/").expect("default AWS STS endpoint is valid")
}
/// Parameters for a single STS `AssumeRoleWithWebIdentity` call.
///
/// `Debug` is implemented by hand rather than derived. `web_identity_token` is a bearer
/// credential that STS exchanges for AWS keys, so it must never end up in logs. Every other
/// field is printed as-is.
#[derive(Clone)]
pub struct AssumeRoleWithWebIdentityRequest {
    /// ARN of the role to assume (sent as `RoleArn`).
    pub role_arn: String,
    /// Name for the resulting role session (sent as `RoleSessionName`).
    pub role_session_name: String,
    /// Identity-provider token presented to STS (sent as `WebIdentityToken`).
    pub web_identity_token: String,
    /// Requested session lifetime in seconds (sent as `DurationSeconds`); omitted when `None`.
    pub duration_seconds: Option<u32>,
    /// Inline session policy (sent as `Policy`); omitted when `None`.
    pub policy: Option<String>,
    /// Managed policy ARNs, sent as `PolicyArns.member.N.arn` with `N` counting from 1.
    pub policy_arns: Vec<String>,
    /// Optional `ProviderId` parameter; omitted when `None`.
    pub provider_id: Option<String>,
}

impl std::fmt::Debug for AssumeRoleWithWebIdentityRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AssumeRoleWithWebIdentityRequest")
            .field("role_arn", &self.role_arn)
            .field("role_session_name", &self.role_session_name)
            // Bearer credential: print a placeholder, never the token itself.
            .field("web_identity_token", &"<redacted>")
            .field("duration_seconds", &self.duration_seconds)
            .field("policy", &self.policy)
            .field("policy_arns", &self.policy_arns)
            .field("provider_id", &self.provider_id)
            .finish()
    }
}
impl AssumeRoleWithWebIdentityRequest {
    /// Creates a request that carries only the role ARN, session name and web-identity token.
    ///
    /// All optional parameters start unset and the managed-policy list starts empty.
    pub fn new(
        role_arn: impl Into<String>,
        role_session_name: impl Into<String>,
        web_identity_token: impl Into<String>,
    ) -> Self {
        let (role_arn, role_session_name, web_identity_token) = (
            role_arn.into(),
            role_session_name.into(),
            web_identity_token.into(),
        );
        Self {
            role_arn,
            role_session_name,
            web_identity_token,
            duration_seconds: None,
            policy: None,
            policy_arns: Vec::new(),
            provider_id: None,
        }
    }

    /// Sets the requested session duration in seconds, replacing any earlier value.
    pub fn with_duration_seconds(self, duration: u32) -> Self {
        Self {
            duration_seconds: Some(duration),
            ..self
        }
    }

    /// Sets the inline session policy, replacing any earlier value.
    pub fn with_policy(self, policy: impl Into<String>) -> Self {
        Self {
            policy: Some(policy.into()),
            ..self
        }
    }

    /// Appends one managed policy ARN. ARNs keep the order in which they were added.
    pub fn with_policy_arn(self, arn: impl Into<String>) -> Self {
        let mut next = self;
        next.policy_arns.push(arn.into());
        next
    }

    /// Sets the `ProviderId` parameter, replacing any earlier value.
    pub fn with_provider_id(self, provider_id: impl Into<String>) -> Self {
        Self {
            provider_id: Some(provider_id.into()),
            ..self
        }
    }
}
/// Temporary AWS credentials returned by a successful `AssumeRoleWithWebIdentity` call.
///
/// The three secret strings are wrapped in [`ZeroizedString`].
// NOTE(review): `Debug` is derived, so whether `{:?}` exposes these secrets depends entirely on
// `ZeroizedString`'s own `Debug` impl — confirm it redacts.
#[derive(Debug)]
pub struct AwsCredentials {
    /// Value of the response's `Credentials/AccessKeyId` element; never empty.
    pub access_key_id: ZeroizedString,
    /// Value of the response's `Credentials/SecretAccessKey` element; never empty.
    pub secret_access_key: ZeroizedString,
    /// Value of the response's `Credentials/SessionToken` element; never empty.
    pub session_token: ZeroizedString,
    /// `Credentials/Expiration`, parsed as RFC 3339 and converted to UTC.
    pub expiration: DateTime<Utc>,
}
/// Decoded result of a successful `AssumeRoleWithWebIdentity` call.
#[derive(Debug)]
pub struct AssumeRoleWithWebIdentityResponse {
    /// The temporary credentials issued for the assumed role.
    pub credentials: AwsCredentials,
    /// `AssumedRoleUser/Arn` from the response.
    pub assumed_role_arn: String,
    /// `AssumedRoleUser/AssumedRoleId` from the response.
    pub assumed_role_id: String,
    /// `SubjectFromWebIdentityToken` from the response (required by the decoder).
    pub subject_from_web_identity_token: String,
    /// `Provider` from the response, if the element was present.
    pub provider: Option<String>,
    /// `Audience` from the response, if the element was present.
    pub audience: Option<String>,
}
/// Failure modes of [`AwsStsClient::assume_role_with_web_identity`].
#[derive(Debug, thiserror::Error)]
pub enum AwsStsError {
    /// The HTTP request could not be sent, or the response body could not be read.
    /// Holds the stringified reqwest error.
    #[error("AWS STS transport error: {0}")]
    Transport(String),
    /// STS answered with a non-success HTTP status.
    ///
    /// When the body is not a recognisable STS error document, `code` is `"Unknown"` and
    /// `message` quotes the raw body.
    #[error("AWS STS error [{code}] (HTTP {http_status}): {message}")]
    StsError {
        /// HTTP status code of the response.
        http_status: u16,
        /// `Error/Code` from the error document, or `"Unknown"`.
        code: String,
        /// `Error/Message` from the error document, or a description of the raw body.
        message: String,
        /// `Error/Type` from the error document, if present.
        fault_type: Option<String>,
        /// `RequestId` from the error document, if present.
        request_id: Option<String>,
    },
    /// A success response whose body could not be decoded or failed validation
    /// (bad XML shape, non-RFC 3339 expiration, or an empty credential string).
    #[error("malformed AWS STS response: {0}")]
    MalformedResponse(String),
}
/// Minimal STS client that issues `AssumeRoleWithWebIdentity` requests over HTTP.
///
/// Cloning shares the endpoint through an `Arc` and clones the reqwest client handle.
// NOTE(review): reqwest documents `Client` as internally reference-counted, which would make
// clones cheap and share one connection pool — confirm against the reqwest version in use.
#[derive(Clone)]
pub struct AwsStsClient {
    /// STS endpoint every request is POSTed to.
    endpoint: Arc<Url>,
    /// HTTP client used to send requests.
    http: reqwest::Client,
}
impl Default for AwsStsClient {
fn default() -> Self {
Self::new()
}
}
impl AwsStsClient {
    /// Total per-request deadline (connect, send, and read the full response body) applied by
    /// the HTTP client that [`AwsStsClient::new`] builds.
    ///
    /// `reqwest::Client::new()` applies no timeout at all, so an unresponsive endpoint would
    /// stall the caller indefinitely. Callers needing a different policy can install their own
    /// client with [`AwsStsClient::with_http_client`].
    const DEFAULT_HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    /// Creates a client that targets the global STS endpoint. Its HTTP client enforces
    /// [`Self::DEFAULT_HTTP_TIMEOUT`] on every request.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be constructed (for example, the TLS backend fails to
    /// initialise). This is the same condition under which `reqwest::Client::new()` panics.
    pub fn new() -> Self {
        let http = reqwest::Client::builder()
            .timeout(Self::DEFAULT_HTTP_TIMEOUT)
            .build()
            .expect("failed to build default HTTP client for AWS STS");
        Self {
            endpoint: Arc::new(default_endpoint()),
            http,
        }
    }

    /// Replaces the endpoint requests are POSTed to (e.g. a regional endpoint or a test
    /// server).
    pub fn with_endpoint(mut self, endpoint: Url) -> Self {
        self.endpoint = Arc::new(endpoint);
        self
    }

    /// Replaces the HTTP client. The supplied client's own timeout configuration is used
    /// unchanged.
    pub fn with_http_client(mut self, http: reqwest::Client) -> Self {
        self.http = http;
        self
    }

    /// Returns the endpoint requests are POSTed to.
    pub fn endpoint(&self) -> &Url {
        &self.endpoint
    }

    /// Calls STS `AssumeRoleWithWebIdentity` with `request` and decodes the XML reply.
    ///
    /// # Errors
    ///
    /// - [`AwsStsError::Transport`] if the request cannot be sent or the body cannot be read.
    ///   This includes hitting the client's timeout.
    /// - [`AwsStsError::StsError`] if STS answers with a non-success status.
    /// - [`AwsStsError::MalformedResponse`] if a success body cannot be decoded or validated.
    pub async fn assume_role_with_web_identity(
        &self,
        request: &AssumeRoleWithWebIdentityRequest,
    ) -> Result<AssumeRoleWithWebIdentityResponse, AwsStsError> {
        let form = build_form(request);
        let response = self
            .http
            .post((*self.endpoint).clone())
            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
            .form(&form)
            .send()
            .await
            .map_err(|e| AwsStsError::Transport(e.to_string()))?;
        let status = response.status();
        // Read the body before branching: both the success and error paths parse it.
        let body = response
            .text()
            .await
            .map_err(|e| AwsStsError::Transport(e.to_string()))?;
        if !status.is_success() {
            return Err(parse_error(status.as_u16(), &body));
        }
        parse_success(&body)
    }
}
/// Encodes `request` as the ordered form fields of an STS `AssumeRoleWithWebIdentity` call.
///
/// The five required fields come first: `Action`, `Version`, `RoleArn`, `RoleSessionName`,
/// `WebIdentityToken`. Each optional field that is set follows: `DurationSeconds`, `Policy`,
/// then `PolicyArns.member.N.arn` (with `N` counting from 1), then `ProviderId`.
fn build_form(request: &AssumeRoleWithWebIdentityRequest) -> Vec<(String, String)> {
    let pair = |key: &str, value: &str| (key.to_owned(), value.to_owned());

    let mut fields = vec![
        pair("Action", "AssumeRoleWithWebIdentity"),
        pair("Version", STS_API_VERSION),
        pair("RoleArn", request.role_arn.as_str()),
        pair("RoleSessionName", request.role_session_name.as_str()),
        pair("WebIdentityToken", request.web_identity_token.as_str()),
    ];

    fields.extend(
        request
            .duration_seconds
            .map(|secs| ("DurationSeconds".to_owned(), secs.to_string())),
    );
    fields.extend(request.policy.as_deref().map(|doc| pair("Policy", doc)));
    fields.extend(
        request
            .policy_arns
            .iter()
            .zip(1_usize..)
            .map(|(arn, ordinal)| (format!("PolicyArns.member.{}.arn", ordinal), arn.to_owned())),
    );
    fields.extend(request.provider_id.as_deref().map(|id| pair("ProviderId", id)));

    fields
}
/// Top-level shape of a successful `AssumeRoleWithWebIdentity` XML response body.
// NOTE(review): the root element's own name is not declared here; this presumably relies on
// quick_xml matching only child elements — confirm against the quick_xml version in use.
#[derive(Debug, Deserialize)]
struct StsSuccessEnvelope {
    /// The `<AssumeRoleWithWebIdentityResult>` child element.
    #[serde(rename = "AssumeRoleWithWebIdentityResult")]
    result: StsAssumeRoleResult,
}
/// Contents of `<AssumeRoleWithWebIdentityResult>`.
///
/// `Provider` and `Audience` may be absent (`serde(default)`). The other elements are required
/// for decoding to succeed.
#[derive(Debug, Deserialize)]
struct StsAssumeRoleResult {
    /// The `<Credentials>` element.
    #[serde(rename = "Credentials")]
    credentials: StsCredentialsXml,
    /// The `<AssumedRoleUser>` element.
    #[serde(rename = "AssumedRoleUser")]
    assumed_role_user: StsAssumedRoleUserXml,
    /// The `<SubjectFromWebIdentityToken>` element.
    #[serde(rename = "SubjectFromWebIdentityToken")]
    subject_from_web_identity_token: String,
    /// The `<Provider>` element; `None` when absent.
    #[serde(rename = "Provider", default)]
    provider: Option<String>,
    /// The `<Audience>` element; `None` when absent.
    #[serde(rename = "Audience", default)]
    audience: Option<String>,
}
/// Raw `<Credentials>` element exactly as decoded from XML.
///
/// `parse_success` validates these strings (non-empty credentials, RFC 3339 expiration) before
/// moving them into the public [`AwsCredentials`].
// NOTE(review): these secrets sit in plain `String`s with a derived `Debug` until conversion;
// nothing in this file prints them, but keep it that way.
#[derive(Debug, Deserialize)]
struct StsCredentialsXml {
    /// `<AccessKeyId>` text.
    #[serde(rename = "AccessKeyId")]
    access_key_id: String,
    /// `<SecretAccessKey>` text.
    #[serde(rename = "SecretAccessKey")]
    secret_access_key: String,
    /// `<SessionToken>` text.
    #[serde(rename = "SessionToken")]
    session_token: String,
    /// `<Expiration>` text; parsed later as RFC 3339.
    #[serde(rename = "Expiration")]
    expiration: String,
}
/// Raw `<AssumedRoleUser>` element.
#[derive(Debug, Deserialize)]
struct StsAssumedRoleUserXml {
    /// `<Arn>` text: the ARN of the assumed-role session.
    #[serde(rename = "Arn")]
    arn: String,
    /// `<AssumedRoleId>` text.
    #[serde(rename = "AssumedRoleId")]
    assumed_role_id: String,
}
fn parse_success(body: &str) -> Result<AssumeRoleWithWebIdentityResponse, AwsStsError> {
let parsed: StsSuccessEnvelope = quick_xml::de::from_str(body)
.map_err(|e| AwsStsError::MalformedResponse(format!("XML decode: {e}")))?;
let creds = parsed.result.credentials;
let expiration = DateTime::parse_from_rfc3339(&creds.expiration)
.map_err(|e| AwsStsError::MalformedResponse(format!("Expiration not RFC3339: {e}")))?
.with_timezone(&Utc);
if creds.access_key_id.is_empty()
|| creds.secret_access_key.is_empty()
|| creds.session_token.is_empty()
{
return Err(AwsStsError::MalformedResponse(
"credentials block missing one of access_key_id / secret_access_key / session_token"
.to_string(),
));
}
Ok(AssumeRoleWithWebIdentityResponse {
credentials: AwsCredentials {
access_key_id: ZeroizedString::from(creds.access_key_id),
secret_access_key: ZeroizedString::from(creds.secret_access_key),
session_token: ZeroizedString::from(creds.session_token),
expiration,
},
assumed_role_arn: parsed.result.assumed_role_user.arn,
assumed_role_id: parsed.result.assumed_role_user.assumed_role_id,
subject_from_web_identity_token: parsed.result.subject_from_web_identity_token,
provider: parsed.result.provider,
audience: parsed.result.audience,
})
}
/// Top-level shape of an STS error response body: an `<Error>` element plus an optional
/// `<RequestId>` sibling.
#[derive(Debug, Deserialize)]
struct StsErrorEnvelope {
    /// The `<Error>` element.
    #[serde(rename = "Error")]
    error: StsErrorBody,
    /// The `<RequestId>` element; `None` when absent.
    #[serde(rename = "RequestId", default)]
    request_id: Option<String>,
}
/// Contents of the `<Error>` element. `Code` and `Message` are required to decode; `Type` is
/// optional.
#[derive(Debug, Deserialize)]
struct StsErrorBody {
    /// `<Type>` text; `None` when absent.
    #[serde(rename = "Type", default)]
    fault_type: Option<String>,
    /// `<Code>` text.
    #[serde(rename = "Code")]
    code: String,
    /// `<Message>` text.
    #[serde(rename = "Message")]
    message: String,
}
/// Maximum number of characters of an unparseable error body copied into the `message` of
/// [`AwsStsError::StsError`].
///
/// A body that does not decode as an STS error document was most likely produced by something
/// in front of STS (a proxy, gateway, or load-balancer page) and can be arbitrarily large.
/// Copying it verbatim would bloat every log line that formats the error.
const MAX_UNPARSED_ERROR_BODY_CHARS: usize = 512;

/// Converts a non-success STS response into an [`AwsStsError::StsError`].
///
/// If `body` decodes as an STS error document, its `Code`, `Message`, optional `Type` and
/// optional `RequestId` are carried over verbatim. Otherwise the error gets code `"Unknown"`
/// and a message quoting at most [`MAX_UNPARSED_ERROR_BODY_CHARS`] characters of the raw body.
fn parse_error(http_status: u16, body: &str) -> AwsStsError {
    match quick_xml::de::from_str::<StsErrorEnvelope>(body) {
        Ok(env) => AwsStsError::StsError {
            http_status,
            code: env.error.code,
            message: env.error.message,
            fault_type: env.error.fault_type,
            request_id: env.request_id,
        },
        Err(_) => AwsStsError::StsError {
            http_status,
            code: "Unknown".to_string(),
            message: format!(
                "non-XML error body: {}",
                truncate_error_body(body, MAX_UNPARSED_ERROR_BODY_CHARS)
            ),
            fault_type: None,
            request_id: None,
        },
    }
}

/// Returns `text` unchanged if it has at most `max_chars` characters. Otherwise returns its
/// first `max_chars` characters followed by a marker recording how many bytes were dropped.
///
/// The cut is made on a `char` boundary, so a multi-byte UTF-8 sequence is never split.
fn truncate_error_body(text: &str, max_chars: usize) -> std::borrow::Cow<'_, str> {
    match text.char_indices().nth(max_chars) {
        None => std::borrow::Cow::Borrowed(text),
        Some((cut, _)) => std::borrow::Cow::Owned(format!(
            "{}... [truncated {} of {} bytes]",
            &text[..cut],
            text.len() - cut,
            text.len()
        )),
    }
}
#[cfg(test)]
mod tests;