use super::StorageError;
use crate::http_client::Client;
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
/// Cached credentials are treated as stale this many seconds before their recorded expiry.
const REFRESH_MARGIN_SECS: u64 = 120;
/// Base URL used for the IMDSv2 token, role-list and credentials requests.
const IMDS_HOST: &str = "http://169.254.169.254";
/// Base URL that an ECS relative credentials path is appended to.
const ECS_HOST: &str = "http://169.254.170.2";
/// Timeout, in milliseconds, applied to the STS and ECS credential requests.
const HTTP_TIMEOUT_MS: u64 = 5_000;
/// Timeout, in milliseconds, applied to each individual IMDS request.
const IMDS_TIMEOUT_MS: u64 = 400;
/// A set of AWS credentials together with an optional expiry time.
///
/// NOTE(review): `Debug` is derived, so formatting a value with `{:?}` prints
/// `secret_key` and `session_token` verbatim — avoid logging values of this type.
#[derive(Clone, Debug, PartialEq)]
pub(super) struct Credentials {
    /// Access key id.
    pub access_key: String,
    /// Secret access key.
    pub secret_key: String,
    /// Session token; `None` for the static credentials constructed in this file.
    pub session_token: Option<String>,
    /// Expiry as seconds since the Unix epoch; `None` means the credentials are
    /// always considered fresh.
    pub expires_at_epoch_secs: Option<u64>,
}
impl Credentials {
    /// Reports whether these credentials can still be used at `now_epoch_secs`.
    ///
    /// Credentials without an expiry are always fresh. Otherwise they are fresh
    /// only while `now_epoch_secs + refresh_margin_secs` is strictly below the
    /// expiry, i.e. they go stale `refresh_margin_secs` before they actually expire.
    ///
    /// The sum saturates at `u64::MAX` instead of overflowing: an overflowing sum
    /// would panic in debug builds and wrap around in release builds, which could
    /// report long-expired credentials as fresh.
    fn is_fresh(&self, now_epoch_secs: u64, refresh_margin_secs: u64) -> bool {
        match self.expires_at_epoch_secs {
            None => true,
            Some(exp) => now_epoch_secs.saturating_add(refresh_margin_secs) < exp,
        }
    }
}
/// Where a `CredentialsProvider` obtains its credentials from.
///
/// NOTE(review): `Debug` is derived, so `{:?}` prints the embedded static
/// credentials and the ECS `auth_token` verbatim.
#[derive(Debug, Clone, PartialEq)]
pub(super) enum Source {
    /// Fixed credentials that are handed back unchanged on every fetch.
    Static(Credentials),
    /// Role ARN plus the path of a web identity token file; the token is exchanged
    /// for credentials through STS `AssumeRoleWithWebIdentity`.
    Irsa { role_arn: String, token_file: String },
    /// Credentials path that is appended to `ECS_HOST`; an empty path makes the
    /// fetch fail with an error.
    EcsRelative { path: String },
    /// Complete credentials URL, with an optional value sent as the
    /// `Authorization` header.
    EcsFull { url: String, auth_token: Option<String> },
    /// Credentials read from the instance metadata service via the IMDSv2 token flow.
    Imds,
}
/// Chooses a credential source when none was forced explicitly.
///
/// Precedence, first match wins:
/// 1. both static keys supplied by the caller → [`Source::Static`];
/// 2. `AWS_ROLE_ARN` and `AWS_WEB_IDENTITY_TOKEN_FILE` both set → [`Source::Irsa`];
/// 3. `AWS_CONTAINER_CREDENTIALS_FULL_URI` set → [`Source::EcsFull`], carrying
///    `AWS_CONTAINER_AUTHORIZATION_TOKEN` when present;
/// 4. `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` set → [`Source::EcsRelative`];
/// 5. otherwise [`Source::Imds`].
///
/// An environment variable that cannot be read as Unicode counts as unset.
fn detect_auto(static_access_key: Option<String>, static_secret_key: Option<String>) -> Source {
    let env = |name: &str| std::env::var(name).ok();

    if let (Some(access_key), Some(secret_key)) = (static_access_key, static_secret_key) {
        Source::Static(Credentials {
            access_key,
            secret_key,
            session_token: None,
            expires_at_epoch_secs: None,
        })
    } else if let (Some(role_arn), Some(token_file)) =
        (env("AWS_ROLE_ARN"), env("AWS_WEB_IDENTITY_TOKEN_FILE"))
    {
        Source::Irsa { role_arn, token_file }
    } else if let Some(url) = env("AWS_CONTAINER_CREDENTIALS_FULL_URI") {
        Source::EcsFull { url, auth_token: env("AWS_CONTAINER_AUTHORIZATION_TOKEN") }
    } else if let Some(path) = env("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") {
        Source::EcsRelative { path }
    } else {
        Source::Imds
    }
}
/// Builds the ECS source from the container-credentials environment variables.
///
/// `AWS_CONTAINER_CREDENTIALS_FULL_URI` takes precedence and is paired with the
/// optional `AWS_CONTAINER_AUTHORIZATION_TOKEN`. Without it the result is a
/// relative source whose path comes from `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI`,
/// or is the empty string when that variable is unset as well.
fn ecs_source_from_env() -> Source {
    match std::env::var("AWS_CONTAINER_CREDENTIALS_FULL_URI") {
        Ok(url) => {
            let auth_token = std::env::var("AWS_CONTAINER_AUTHORIZATION_TOKEN").ok();
            Source::EcsFull { url, auth_token }
        }
        Err(_) => {
            let path = std::env::var("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI").unwrap_or_default();
            Source::EcsRelative { path }
        }
    }
}
/// Fetches credentials from a chosen [`Source`] and caches the most recent result.
pub(super) struct CredentialsProvider {
    /// Where credentials come from; set once at construction.
    source: Source,
    /// Region name substituted into the STS endpoint host for the IRSA source.
    region: String,
    /// HTTP client used for every credential request.
    client: Client,
    /// Most recently fetched credentials, if any.
    cached: Mutex<Option<Credentials>>,
}
impl CredentialsProvider {
pub(super) fn detect(region: &str, static_access_key: Option<String>, static_secret_key: Option<String>) -> Self {
let source = match std::env::var("RWS_S3_CREDENTIAL_SOURCE").ok().as_deref() {
Some("static") => Source::Static(Credentials {
access_key: static_access_key.unwrap_or_default(),
secret_key: static_secret_key.unwrap_or_default(),
session_token: None,
expires_at_epoch_secs: None,
}),
Some("irsa") => Source::Irsa {
role_arn: std::env::var("AWS_ROLE_ARN").unwrap_or_default(),
token_file: std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE").unwrap_or_default(),
},
Some("ecs") => ecs_source_from_env(),
Some("imds") => Source::Imds,
_ => detect_auto(static_access_key, static_secret_key),
};
Self::new(source, region)
}
fn new(source: Source, region: &str) -> Self {
CredentialsProvider { source, region: region.to_string(), client: Client::new(), cached: Mutex::new(None) }
}
pub(super) fn get(&self) -> Result<Credentials, StorageError> {
let now = epoch_now();
{
let guard = self.cached.lock().unwrap();
if let Some(c) = guard.as_ref() {
if c.is_fresh(now, REFRESH_MARGIN_SECS) {
return Ok(c.clone());
}
}
}
let fresh = self.fetch()?;
*self.cached.lock().unwrap() = Some(fresh.clone());
Ok(fresh)
}
fn fetch(&self) -> Result<Credentials, StorageError> {
match &self.source {
Source::Static(creds) => Ok(creds.clone()),
Source::Irsa { role_arn, token_file } => {
let sts_base_url = format!("https://sts.{}.amazonaws.com", self.region);
fetch_irsa_credentials(&self.client, &sts_base_url, role_arn, token_file)
}
Source::EcsRelative { path } => {
if path.is_empty() {
return Err(StorageError::new(
"ECS task role credentials requested but neither AWS_CONTAINER_CREDENTIALS_RELATIVE_URI nor _FULL_URI is set",
));
}
fetch_ecs_credentials(&self.client, &format!("{ECS_HOST}{path}"), None)
}
Source::EcsFull { url, auth_token } => fetch_ecs_credentials(&self.client, url, auth_token.as_deref()),
Source::Imds => fetch_imds_credentials(&self.client, IMDS_HOST),
}
}
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch.
fn epoch_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Exchanges a web identity token for credentials via STS `AssumeRoleWithWebIdentity`.
///
/// The token is read from `token_file`, trimmed, and sent together with
/// `role_arn` as URL-encoded query parameters of a GET request to `sts_base_url`.
///
/// # Errors
/// Fails when `role_arn` or `token_file` is empty, the token file cannot be read,
/// the request fails or returns a non-success status, or the response body cannot
/// be read or parsed.
fn fetch_irsa_credentials(client: &Client, sts_base_url: &str, role_arn: &str, token_file: &str) -> Result<Credentials, StorageError> {
    let missing_config = role_arn.is_empty() || token_file.is_empty();
    if missing_config {
        return Err(StorageError::new(
            "IRSA credentials requested but AWS_ROLE_ARN and/or AWS_WEB_IDENTITY_TOKEN_FILE is not set",
        ));
    }

    let raw_token = match std::fs::read_to_string(token_file) {
        Ok(contents) => contents,
        Err(e) => return Err(StorageError::new(format!("reading AWS_WEB_IDENTITY_TOKEN_FILE '{token_file}': {e}"))),
    };
    let encoded_role = url_search_params::encode_uri_component(role_arn);
    let encoded_token = url_search_params::encode_uri_component(raw_token.trim());
    let url = format!(
        "{sts_base_url}/?Action=AssumeRoleWithWebIdentity&Version=2011-06-15&RoleArn={}&WebIdentityToken={}&RoleSessionName=rws-s3",
        encoded_role, encoded_token,
    );

    let resp = match client.get(&url).timeout_ms(HTTP_TIMEOUT_MS).send() {
        Ok(resp) => resp,
        Err(e) => return Err(StorageError::new(format!("STS AssumeRoleWithWebIdentity request failed: {e}"))),
    };
    if resp.is_success() {
        let body = resp.text().map_err(|e| StorageError::new(format!("reading STS response: {e}")))?;
        return parse_sts_response(&body);
    }

    // Non-success status: include whatever body STS returned to aid diagnosis.
    Err(StorageError::new(format!(
        "STS AssumeRoleWithWebIdentity failed: HTTP {} {}",
        resp.status(),
        resp.text().unwrap_or_default()
    )))
}
/// Extracts credentials from an STS `AssumeRoleWithWebIdentity` XML response.
///
/// `AccessKeyId`, `SecretAccessKey` and `SessionToken` are required. A missing or
/// unparsable `Expiration` is recorded as epoch second `0`, which makes the
/// credentials count as already expired.
///
/// # Errors
/// Fails when any of the three required elements is absent.
fn parse_sts_response(xml: &str) -> Result<Credentials, StorageError> {
    let Some(access_key) = extract_tag(xml, "AccessKeyId") else {
        return Err(StorageError::new("STS AssumeRoleWithWebIdentity response missing AccessKeyId"));
    };
    let Some(secret_key) = extract_tag(xml, "SecretAccessKey") else {
        return Err(StorageError::new("STS AssumeRoleWithWebIdentity response missing SecretAccessKey"));
    };
    let Some(session_token) = extract_tag(xml, "SessionToken") else {
        return Err(StorageError::new("STS AssumeRoleWithWebIdentity response missing SessionToken"));
    };
    let expiry = match extract_tag(xml, "Expiration") {
        Some(text) => parse_iso8601_epoch(&text).unwrap_or(0),
        None => 0,
    };

    Ok(Credentials {
        access_key,
        secret_key,
        session_token: Some(session_token),
        expires_at_epoch_secs: Some(expiry),
    })
}
/// Fetches instance-role credentials through the IMDSv2 flow.
///
/// Three requests are made against `imds_base_url`, each with the IMDS timeout:
/// a PUT for a session token, a GET listing the attached roles, and a GET for the
/// credentials of the first listed role.
///
/// # Errors
/// Fails when any request errors or returns a non-success status, a body cannot
/// be read, no role is listed, or the credentials document cannot be parsed.
fn fetch_imds_credentials(client: &Client, imds_base_url: &str) -> Result<Credentials, StorageError> {
    // Step 1: obtain the session token that the following requests must present.
    let token_url = format!("{imds_base_url}/latest/api/token");
    let token_resp = match client
        .put(&token_url)
        .header("X-aws-ec2-metadata-token-ttl-seconds", "21600")
        .timeout_ms(IMDS_TIMEOUT_MS)
        .send()
    {
        Ok(resp) => resp,
        Err(e) => return Err(StorageError::new(format!("IMDSv2 token request failed (not running on EC2? {e})"))),
    };
    if !token_resp.is_success() {
        return Err(StorageError::new(format!("IMDSv2 token request failed: HTTP {}", token_resp.status())));
    }
    let token_body = token_resp.text().map_err(|e| StorageError::new(format!("reading IMDSv2 token: {e}")))?;
    let session_token = token_body.trim();

    // Step 2: list the roles attached to the instance.
    let roles_url = format!("{imds_base_url}/latest/meta-data/iam/security-credentials/");
    let roles_resp = match client
        .get(&roles_url)
        .header("X-aws-ec2-metadata-token", session_token)
        .timeout_ms(IMDS_TIMEOUT_MS)
        .send()
    {
        Ok(resp) => resp,
        Err(e) => return Err(StorageError::new(format!("IMDSv2 role list request failed: {e}"))),
    };
    if !roles_resp.is_success() {
        return Err(StorageError::new(format!("IMDSv2 role list request failed: HTTP {}", roles_resp.status())));
    }
    let roles_body = roles_resp.text().map_err(|e| StorageError::new(format!("reading IMDSv2 role list: {e}")))?;
    // Only the first line of the listing is used as the role name.
    let role = match roles_body.lines().next().map(str::trim) {
        Some(name) if !name.is_empty() => name,
        _ => return Err(StorageError::new("IMDSv2: no IAM role attached to this instance")),
    };

    // Step 3: fetch the credentials document for that role.
    let creds_url = format!("{imds_base_url}/latest/meta-data/iam/security-credentials/{role}");
    let creds_resp = match client
        .get(&creds_url)
        .header("X-aws-ec2-metadata-token", session_token)
        .timeout_ms(IMDS_TIMEOUT_MS)
        .send()
    {
        Ok(resp) => resp,
        Err(e) => return Err(StorageError::new(format!("IMDSv2 credentials request failed: {e}"))),
    };
    if !creds_resp.is_success() {
        return Err(StorageError::new(format!("IMDSv2 credentials request failed: HTTP {}", creds_resp.status())));
    }
    let creds_body = creds_resp.text().map_err(|e| StorageError::new(format!("reading IMDSv2 credentials: {e}")))?;
    parse_json_credentials(&creds_body)
}
/// Fetches task-role credentials from an ECS-style credentials endpoint.
///
/// Issues a GET to `url`; when `auth_token` is given it is sent unchanged as the
/// `Authorization` header value.
///
/// # Errors
/// Fails when the request errors or returns a non-success status, the body cannot
/// be read, or the credentials document cannot be parsed.
fn fetch_ecs_credentials(client: &Client, url: &str, auth_token: Option<&str>) -> Result<Credentials, StorageError> {
    let request = client.get(url).timeout_ms(HTTP_TIMEOUT_MS);
    let request = match auth_token {
        Some(tok) => request.header("Authorization", tok),
        None => request,
    };

    let resp = match request.send() {
        Ok(resp) => resp,
        Err(e) => return Err(StorageError::new(format!("ECS task role credentials request failed: {e}"))),
    };
    if resp.is_success() {
        let body = resp.text().map_err(|e| StorageError::new(format!("reading ECS task role credentials: {e}")))?;
        return parse_json_credentials(&body);
    }
    Err(StorageError::new(format!("ECS task role credentials request failed: HTTP {}", resp.status())))
}
/// Extracts credentials from a JSON credentials document (IMDS or ECS endpoint).
///
/// `AccessKeyId`, `SecretAccessKey` and `Token` are required string fields. A
/// missing or unparsable `Expiration` is recorded as epoch second `0`, which makes
/// the credentials count as already expired.
///
/// # Errors
/// Fails when any of the three required fields is absent.
fn parse_json_credentials(json: &str) -> Result<Credentials, StorageError> {
    let Some(access_key) = extract_json_str_field(json, "AccessKeyId") else {
        return Err(StorageError::new("credentials response missing AccessKeyId"));
    };
    let Some(secret_key) = extract_json_str_field(json, "SecretAccessKey") else {
        return Err(StorageError::new("credentials response missing SecretAccessKey"));
    };
    let Some(token) = extract_json_str_field(json, "Token") else {
        return Err(StorageError::new("credentials response missing Token"));
    };
    let expiry = match extract_json_str_field(json, "Expiration") {
        Some(text) => parse_iso8601_epoch(&text).unwrap_or(0),
        None => 0,
    };

    Ok(Credentials {
        access_key,
        secret_key,
        session_token: Some(token),
        expires_at_epoch_secs: Some(expiry),
    })
}
/// Returns the text between the first `<tag>` and the next `</tag>` in `xml`.
///
/// This is a plain substring search, not an XML parser: tags carrying attributes
/// are not matched and entity references are left undecoded. Yields `None` when
/// either the opening or the closing tag is missing.
fn extract_tag(xml: &str, tag: &str) -> Option<String> {
    let opening = format!("<{tag}>");
    let closing = format!("</{tag}>");
    let (_, after_open) = xml.split_once(opening.as_str())?;
    let (inner, _) = after_open.split_once(closing.as_str())?;
    Some(inner.to_owned())
}
/// Returns the decoded string value of the first `"field": "…"` pair found in `json`.
///
/// This is a lightweight scanner rather than a full JSON parser: it locates the
/// first textual occurrence of the quoted key and requires it to be followed by a
/// colon and a string value. JSON escape sequences inside the value (`\"`, `\\`,
/// `\/`, `\b`, `\f`, `\n`, `\r`, `\t` and `\uXXXX`, including surrogate pairs) are
/// decoded, so an escaped quote no longer truncates the value and escaped slashes
/// in tokens come back as plain `/`.
///
/// Yields `None` when the key is absent, the value is not a string, the string is
/// unterminated, or it contains a malformed escape sequence.
fn extract_json_str_field(json: &str, field: &str) -> Option<String> {
    // Reads exactly four hex digits of a `\uXXXX` escape.
    fn hex4(chars: &mut std::str::Chars<'_>) -> Option<u32> {
        let mut code = 0u32;
        for _ in 0..4 {
            code = code * 16 + chars.next()?.to_digit(16)?;
        }
        Some(code)
    }

    let key = format!("\"{field}\"");
    let start = json.find(key.as_str())?;
    let rest = json[start + key.len()..].trim_start();
    let rest = rest.strip_prefix(':')?.trim_start();
    let rest = rest.strip_prefix('"')?;

    let mut value = String::new();
    let mut chars = rest.chars();
    loop {
        // Running out of input before the closing quote means the string is unterminated.
        match chars.next()? {
            '"' => return Some(value),
            '\\' => {
                let decoded = match chars.next()? {
                    '"' => '"',
                    '\\' => '\\',
                    '/' => '/',
                    'b' => '\u{0008}',
                    'f' => '\u{000C}',
                    'n' => '\n',
                    'r' => '\r',
                    't' => '\t',
                    'u' => {
                        let unit = hex4(&mut chars)?;
                        let code = if (0xD800..0xDC00).contains(&unit) {
                            // A high surrogate must be completed by a `\uXXXX` low surrogate.
                            if chars.next()? != '\\' || chars.next()? != 'u' {
                                return None;
                            }
                            let low = hex4(&mut chars)?;
                            if !(0xDC00..0xE000).contains(&low) {
                                return None;
                            }
                            0x10000 + ((unit - 0xD800) << 10) + (low - 0xDC00)
                        } else {
                            unit
                        };
                        // Rejects a lone low surrogate.
                        char::from_u32(code)?
                    }
                    _ => return None,
                };
                value.push(decoded);
            }
            other => value.push(other),
        }
    }
}
/// Converts a timestamp laid out as `YYYY-MM-DDThh:mm:ss…` into epoch seconds.
///
/// Only the fixed character positions of the six numeric fields are read; the
/// separator characters and anything after the seconds field are ignored.
/// Yields `None` when the trimmed input is shorter than 19 bytes or any field
/// fails to parse as a number.
///
/// NOTE(review): any trailing zone designator is ignored, so the input is
/// presumably expected to be UTC — confirm against the callers.
fn parse_iso8601_epoch(s: &str) -> Option<u64> {
    // Parses the number found at a fixed byte range of the timestamp.
    fn field<T: std::str::FromStr>(text: &str, range: std::ops::Range<usize>) -> Option<T> {
        text.get(range)?.parse().ok()
    }

    let stamp = s.trim();
    if stamp.len() < 19 {
        return None;
    }

    let year: u32 = field(stamp, 0..4)?;
    let month: u32 = field(stamp, 5..7)?;
    let day: u32 = field(stamp, 8..10)?;
    let hour: u64 = field(stamp, 11..13)?;
    let min: u64 = field(stamp, 14..16)?;
    let sec: u64 = field(stamp, 17..19)?;

    let days = crate::scheduler::cron::ymd_to_days(year, month, day);
    Some(days * 86400 + hour * 3600 + min * 60 + sec)
}
/// Test-only accessor for a single process-wide mutex; every call returns the
/// same lock.
///
/// Presumably held by tests that modify the credential-related environment
/// variables so that they do not run concurrently — confirm against the callers.
#[cfg(test)]
pub(crate) fn credential_env_lock() -> &'static std::sync::Mutex<()> {
    static LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
    &LOCK
}
#[cfg(test)]
mod tests;