use std::sync::Arc;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use http::StatusCode;
use fakecloud_core::service::{AwsRequest, AwsResponse, AwsService, AwsServiceError};
use fakecloud_core::validation::*;
use fakecloud_persistence::SnapshotStore;
use crate::evaluator::{
evaluate_resource_policy_only, Decision, EvalRequest, PolicyDocument, RequestContext,
};
use crate::persistence::{save_iam_snapshot, IamSnapshotLock};
use crate::state::{CredentialIdentity, IamState, SharedIamState, StsTempCredential};
use crate::xml_responses::{self, StsCredentials};
use fakecloud_core::auth::{Principal, PrincipalType};
/// Credential lifetime in seconds used by the `AssumeRole*` actions when the
/// request omits `DurationSeconds` (one hour).
const DEFAULT_ASSUME_ROLE_DURATION: i64 = 60 * 60;
/// Credential lifetime in seconds used by `GetSessionToken` when the request
/// omits `DurationSeconds` (twelve hours).
const DEFAULT_SESSION_TOKEN_DURATION: i64 = 12 * 60 * 60;
/// Credential lifetime in seconds used by `GetFederationToken` when the
/// request omits `DurationSeconds` (twelve hours).
const DEFAULT_FEDERATION_TOKEN_DURATION: i64 = 12 * 60 * 60;
/// Returns the instant at which credentials issued now should expire.
///
/// Uses the request's `DurationSeconds` query parameter when present,
/// otherwise `default_duration` (seconds). Range checking is left to the
/// caller; this only ensures the value parses as an `i64`.
///
/// # Errors
/// `ValidationError` (HTTP 400) when `DurationSeconds` is not an integer.
fn compute_expiration_at(
    req: &AwsRequest,
    default_duration: i64,
) -> Result<DateTime<Utc>, AwsServiceError> {
    let seconds = match req.query_params.get("DurationSeconds") {
        None => default_duration,
        Some(raw) => raw.parse::<i64>().map_err(|_| {
            AwsServiceError::aws_error(
                StatusCode::BAD_REQUEST,
                "ValidationError",
                format!(
                    "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
                     Member must be a valid integer",
                    raw
                ),
            )
        })?,
    };
    let now = Utc::now();
    Ok(now + chrono::Duration::seconds(seconds))
}
/// Renders an expiration timestamp as ISO-8601 UTC with whole seconds and a
/// trailing `Z`, e.g. `2024-01-01T00:00:00Z`.
fn format_expiration(ts: DateTime<Utc>) -> String {
    const ISO8601_SECONDS_UTC: &str = "%Y-%m-%dT%H:%M:%SZ";
    let rendered = ts.format(ISO8601_SECONDS_UTC);
    rendered.to_string()
}
/// Test helper: the expiration that `compute_expiration_at` would produce,
/// already formatted with `format_expiration`.
#[cfg(test)]
fn compute_expiration(req: &AwsRequest, default_duration: i64) -> Result<String, AwsServiceError> {
    compute_expiration_at(req, default_duration).map(format_expiration)
}
/// STS service implementation backed by the shared IAM state.
///
/// Credential-issuing actions write their temporary credentials into `state`
/// and, on success, trigger an IAM snapshot save (see `handle`).
pub struct StsService {
// Shared IAM account state; read for lookups and written when credentials are issued.
state: SharedIamState,
// Optional persistence backend handed to `save_iam_snapshot`; `None` by default.
snapshot_store: Option<Arc<dyn SnapshotStore>>,
// Lock handed to `save_iam_snapshot`; replaceable via `with_snapshot_lock`,
// presumably so it can be shared with other IAM-state writers — confirm with callers.
snapshot_lock: IamSnapshotLock,
}
impl StsService {
pub fn new(state: SharedIamState) -> Self {
Self {
state,
snapshot_store: None,
snapshot_lock: crate::persistence::new_snapshot_lock(),
}
}
pub fn with_snapshot_store(mut self, store: Arc<dyn SnapshotStore>) -> Self {
self.snapshot_store = Some(store);
self
}
pub fn with_snapshot_lock(mut self, lock: IamSnapshotLock) -> Self {
self.snapshot_lock = lock;
self
}
}
/// Reports whether `action` issues credentials, i.e. changes IAM state and
/// therefore warrants a snapshot save after a successful response.
fn is_mutating_action(action: &str) -> bool {
    const CREDENTIAL_ISSUING_ACTIONS: [&str; 6] = [
        "AssumeRole",
        "AssumeRoleWithWebIdentity",
        "AssumeRoleWithSAML",
        "GetSessionToken",
        "GetFederationToken",
        "AssumeRoot",
    ];
    CREDENTIAL_ISSUING_ACTIONS.contains(&action)
}
#[async_trait]
impl AwsService for StsService {
    /// Service identifier used for routing and as the IAM action namespace.
    fn service_name(&self) -> &str {
        "sts"
    }

    /// Dispatches an STS request to its handler.
    ///
    /// When the action issues credentials and the handler returns a 2xx
    /// response, the IAM state is snapshotted before the response is returned.
    async fn handle(&self, req: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
        let outcome = match req.action.as_str() {
            "GetCallerIdentity" => self.get_caller_identity(&req),
            "AssumeRole" => self.assume_role(&req),
            "AssumeRoleWithWebIdentity" => self.assume_role_with_web_identity(&req),
            "AssumeRoleWithSAML" => self.assume_role_with_saml(&req),
            "GetSessionToken" => self.get_session_token(&req),
            "GetFederationToken" => self.get_federation_token(&req),
            "GetAccessKeyInfo" => self.get_access_key_info(&req),
            "DecodeAuthorizationMessage" => self.decode_authorization_message(&req),
            "AssumeRoot" => self.assume_root(&req),
            _ => Err(AwsServiceError::action_not_implemented("sts", &req.action)),
        };
        let persist = is_mutating_action(&req.action)
            && matches!(&outcome, Ok(resp) if resp.status.is_success());
        if persist {
            save_iam_snapshot(
                &self.state,
                self.snapshot_store.clone(),
                &self.snapshot_lock,
            )
            .await;
        }
        outcome
    }

    /// Every STS action this service can dispatch.
    fn supported_actions(&self) -> &[&str] {
        const ACTIONS: &[&str] = &[
            "GetCallerIdentity",
            "AssumeRole",
            "AssumeRoleWithWebIdentity",
            "AssumeRoleWithSAML",
            "GetSessionToken",
            "GetFederationToken",
            "GetAccessKeyInfo",
            "DecodeAuthorizationMessage",
            "AssumeRoot",
        ];
        ACTIONS
    }

    fn iam_enforceable(&self) -> bool {
        true
    }

    /// Maps a request to the IAM action/resource pair used for authorization.
    ///
    /// Role-assumption actions are scoped to `RoleArn`, `AssumeRoot` to
    /// `TargetPrincipal`; a missing parameter and every other action use `*`.
    /// Unknown actions yield `None`.
    fn iam_action_for(
        &self,
        request: &fakecloud_core::service::AwsRequest,
    ) -> Option<fakecloud_core::auth::IamAction> {
        let param_or_wildcard = |param: &str| -> String {
            request
                .query_params
                .get(param)
                .cloned()
                .unwrap_or_else(|| "*".to_string())
        };
        let (action, resource): (&'static str, String) = match request.action.as_str() {
            "AssumeRole" => ("AssumeRole", param_or_wildcard("RoleArn")),
            "AssumeRoleWithWebIdentity" => {
                ("AssumeRoleWithWebIdentity", param_or_wildcard("RoleArn"))
            }
            "AssumeRoleWithSAML" => ("AssumeRoleWithSAML", param_or_wildcard("RoleArn")),
            "AssumeRoot" => ("AssumeRoot", param_or_wildcard("TargetPrincipal")),
            "GetCallerIdentity" => ("GetCallerIdentity", "*".to_string()),
            "GetSessionToken" => ("GetSessionToken", "*".to_string()),
            "GetFederationToken" => ("GetFederationToken", "*".to_string()),
            "GetAccessKeyInfo" => ("GetAccessKeyInfo", "*".to_string()),
            "DecodeAuthorizationMessage" => ("DecodeAuthorizationMessage", "*".to_string()),
            _ => return None,
        };
        Some(fakecloud_core::auth::IamAction {
            service: "sts",
            action,
            resource,
        })
    }
}
/// Maps a region name to its AWS partition by region-name prefix, falling
/// back to the commercial `aws` partition.
fn partition_for_region(region: &str) -> &str {
    const PARTITION_PREFIXES: [(&str, &str); 5] = [
        ("cn-", "aws-cn"),
        ("us-iso-", "aws-iso"),
        ("us-isob-", "aws-iso-b"),
        ("us-isof-", "aws-iso-f"),
        ("eu-isoe-", "aws-iso-e"),
    ];
    PARTITION_PREFIXES
        .iter()
        .find(|&&(prefix, _)| region.starts_with(prefix))
        .map_or("aws", |&(_, partition)| partition)
}
/// Gathers the session policy documents attached to an STS request.
///
/// The inline `Policy` parameter, if present, comes first. It is followed by
/// one entry per `PolicyArns.member.N.arn` parameter for N = 1..=12, stopping
/// at the first missing index. Each ARN resolves to its default policy
/// version's document, or to the first version if none is marked default. An
/// ARN not found in `state.policies` contributes an empty document, so the
/// entry stays represented.
fn collect_session_policies(req: &AwsRequest, state: &IamState) -> Vec<String> {
    let resolve_managed = |arn: &String| -> String {
        let version = state.policies.get(arn.as_str()).and_then(|policy| {
            policy
                .versions
                .iter()
                .find(|v| v.is_default)
                .or_else(|| policy.versions.first())
        });
        match version {
            Some(found) => found.document.clone(),
            None => {
                tracing::debug!(
                    target: "fakecloud::iam::audit",
                    arn = %arn,
                    "PolicyArns entry does not resolve to a known managed policy; \
                     session will deny all actions covered by this entry"
                );
                String::new()
            }
        }
    };
    let inline = req.query_params.get("Policy").cloned();
    let managed = (1..=12)
        .map_while(|index| {
            req.query_params
                .get(&format!("PolicyArns.member.{index}.arn"))
        })
        .map(resolve_managed);
    inline.into_iter().chain(managed).collect()
}
/// Returns the access key id from the request's SigV4 `authorization`
/// header. Returns `None` when the header is missing, is not valid visible
/// ASCII, or does not parse as SigV4.
fn extract_access_key(req: &AwsRequest) -> Option<String> {
    req.headers
        .get("authorization")
        .and_then(|value| value.to_str().ok())
        .and_then(|header| fakecloud_aws::sigv4::parse_sigv4(header))
        .map(|parsed| parsed.access_key)
}
impl StsService {
/// Handles `GetCallerIdentity`.
///
/// Reports the authenticated principal when one is attached to the request.
/// Otherwise, a request that carries an `authorization` or
/// `x-amz-security-token` header is reported as the default account's root,
/// with a fixed placeholder user id.
///
/// # Errors
/// `MissingAuthenticationTokenException` (HTTP 403) when there is no
/// principal and neither header is present.
fn get_caller_identity(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
    let xml = match req.principal.as_ref() {
        Some(caller) => xml_responses::get_caller_identity_response(
            &caller.account_id,
            &caller.arn,
            &caller.user_id,
            &req.request_id,
        ),
        None => {
            let headers = &req.headers;
            let signed = headers.contains_key("authorization")
                || headers.contains_key("x-amz-security-token");
            if !signed {
                return Err(AwsServiceError::aws_error(
                    StatusCode::FORBIDDEN,
                    "MissingAuthenticationTokenException",
                    "Request is missing Authentication Token",
                ));
            }
            let accounts = self.state.read();
            let account_id = accounts.default_account_id();
            let root_arn = format!(
                "arn:{}:iam::{}:root",
                partition_for_region(&req.region),
                account_id
            );
            xml_responses::get_caller_identity_response(
                account_id,
                &root_arn,
                "FKIAIOSFODNN7EXAMPLE",
                &req.request_id,
            )
        }
    };
    Ok(AwsResponse::xml(StatusCode::OK, xml))
}
/// Handles `AssumeRole`.
///
/// Validates the request parameters. When the role named by `RoleArn` exists
/// in the target account (the account component of the ARN, or the caller's
/// account if it cannot be parsed), it then:
/// - rejects non-service callers of service-linked roles, and
/// - evaluates the role's trust policy against the caller.
///
/// On success, temporary credentials for the
/// `assumed-role/<role>/<session>` identity are recorded in the target
/// account and returned.
///
/// # Errors
/// - `MissingParameter` if `RoleArn` or `RoleSessionName` is absent.
/// - `ValidationError` if a parameter fails its length, range, or integer check.
/// - `AccessDenied` if the caller is not permitted to assume the role.
fn assume_role(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
    let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
        AwsServiceError::aws_error(
            StatusCode::BAD_REQUEST,
            "MissingParameter",
            "The request must contain the parameter RoleArn",
        )
    })?;
    validate_string_length("roleArn", role_arn, 20, 2048)?;
    let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
        AwsServiceError::aws_error(
            StatusCode::BAD_REQUEST,
            "MissingParameter",
            "The request must contain the parameter RoleSessionName",
        )
    })?;
    validate_string_length("roleSessionName", role_session_name, 2, 64)?;
    if let Some(ds) = req.query_params.get("DurationSeconds") {
        let v = ds.parse::<i64>().map_err(|_| {
            AwsServiceError::aws_error(
                StatusCode::BAD_REQUEST,
                "ValidationError",
                format!(
                    "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
                     Member must be a valid integer",
                    ds
                ),
            )
        })?;
        validate_range_i64("durationSeconds", v, 900, 43200)?;
    }
    validate_optional_string_length(
        "externalId",
        req.query_params.get("ExternalId").map(|s| s.as_str()),
        2,
        1224,
    )?;
    validate_optional_string_length(
        "policy",
        req.query_params.get("Policy").map(|s| s.as_str()),
        1,
        2048,
    )?;
    validate_optional_string_length(
        "sourceIdentity",
        req.query_params.get("SourceIdentity").map(|s| s.as_str()),
        2,
        64,
    )?;
    validate_optional_string_length(
        "serialNumber",
        req.query_params.get("SerialNumber").map(|s| s.as_str()),
        9,
        256,
    )?;
    validate_optional_string_length(
        "tokenCode",
        req.query_params.get("TokenCode").map(|s| s.as_str()),
        6,
        6,
    )?;
    let expiration_at = compute_expiration_at(req, DEFAULT_ASSUME_ROLE_DURATION)?;
    let expiration = format_expiration(expiration_at);
    // MFA counts as present whenever both MFA parameters are supplied; the
    // token code itself is not verified here.
    let mfa_present = req.query_params.contains_key("SerialNumber")
        && req.query_params.contains_key("TokenCode");
    let partition = partition_for_region(&req.region);
    let creds = StsCredentials::generate();

    let mut accounts = self.state.write();
    let caller_state = accounts.get_or_create(&req.account_id);
    let session_policies = collect_session_policies(req, caller_state);
    let account_id =
        extract_account_from_arn(role_arn).unwrap_or_else(|| req.account_id.clone());
    let target_state = accounts.get_or_create(&account_id);
    let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
    // Owned copy so the checks below do not hold a borrow of `target_state`.
    let role = target_state.roles.get(role_name).cloned();
    if let Some(role) = role.as_ref() {
        // Unauthenticated callers are evaluated as the root of their own
        // account, in the partition of the request's region.
        let caller_principal = match req.principal.as_ref() {
            Some(p) => p.clone(),
            None => Principal {
                arn: format!("arn:{}:iam::{}:root", partition, req.account_id),
                user_id: req.account_id.clone(),
                account_id: req.account_id.clone(),
                principal_type: PrincipalType::Root,
                source_identity: None,
                tags: None,
            },
        };
        if role.path.starts_with("/aws-service-role/") {
            let expected_service = role
                .path
                .trim_start_matches("/aws-service-role/")
                .trim_end_matches('/');
            let caller_is_service = req
                .principal
                .as_ref()
                .map(|p| p.arn.contains(expected_service))
                .unwrap_or(false);
            if !caller_is_service {
                return Err(AwsServiceError::aws_error(
                    StatusCode::FORBIDDEN,
                    "AccessDenied",
                    format!(
                        "User: {} is not authorized to perform: sts:AssumeRole on resource: {} because the role is a service-linked role for {}",
                        caller_principal.arn, role_arn, expected_service
                    ),
                ));
            }
        }
        let trust_doc = PolicyDocument::parse(&role.assume_role_policy_document);
        let mut context = RequestContext {
            aws_principal_arn: Some(caller_principal.arn.clone()),
            aws_principal_account: Some(caller_principal.account_id.clone()),
            aws_principal_type: Some(caller_principal.principal_type.as_str().to_string()),
            aws_mfa_present: Some(mfa_present),
            ..Default::default()
        };
        if let Some(eid) = req.query_params.get("ExternalId") {
            context
                .service_keys
                .insert("sts:externalid".to_string(), vec![eid.clone()]);
        }
        context.service_keys.insert(
            "sts:rolesessionname".to_string(),
            vec![role_session_name.clone()],
        );
        if let Some(src) = req.query_params.get("SourceIdentity") {
            context
                .service_keys
                .insert("sts:sourceidentity".to_string(), vec![src.clone()]);
        }
        context.service_keys.insert(
            "aws:sourceaccount".to_string(),
            vec![caller_principal.account_id.clone()],
        );
        let eval_req = EvalRequest {
            principal: &caller_principal,
            action: "sts:AssumeRole".to_string(),
            resource: role_arn.clone(),
            context,
        };
        if !matches!(
            evaluate_resource_policy_only(&trust_doc, &eval_req),
            Decision::Allow
        ) {
            return Err(AwsServiceError::aws_error(
                StatusCode::FORBIDDEN,
                "AccessDenied",
                format!(
                    "User: {} is not authorized to perform: sts:AssumeRole on resource: {}",
                    caller_principal.arn, role_arn
                ),
            ));
        }
    }
    // Use the role's real id when it exists so AssumedRoleId matches the role.
    let role_id = role
        .as_ref()
        .map(|r| r.role_id.clone())
        .unwrap_or_else(xml_responses::generate_role_id);
    let assumed_role_arn = format!(
        "arn:{}:sts::{}:assumed-role/{}/{}",
        partition, account_id, role_name, role_session_name
    );
    let assumed_role_id = format!("{}:{}", role_id, role_session_name);
    target_state.credential_identities.insert(
        creds.access_key_id.clone(),
        CredentialIdentity {
            arn: assumed_role_arn.clone(),
            user_id: assumed_role_id.clone(),
            account_id: account_id.clone(),
        },
    );
    target_state.sts_temp_credentials.insert(
        creds.access_key_id.clone(),
        StsTempCredential {
            access_key_id: creds.access_key_id.clone(),
            secret_access_key: creds.secret_access_key.clone(),
            session_token: creds.session_token.clone(),
            principal_arn: assumed_role_arn,
            user_id: assumed_role_id,
            account_id: account_id.clone(),
            expiration: expiration_at,
            session_policies,
            mfa_present,
            issued_at: Utc::now(),
            federated_provider: None,
        },
    );
    let xml = xml_responses::assume_role_response(&xml_responses::AssumedRoleInfo {
        role_arn,
        role_session_name,
        assumed_role_id: &role_id,
        account_id: &account_id,
        partition,
        creds: &creds,
        expiration: &expiration,
        request_id: &req.request_id,
    });
    Ok(AwsResponse::xml(StatusCode::OK, xml))
}
fn assume_role_with_web_identity(
&self,
req: &AwsRequest,
) -> Result<AwsResponse, AwsServiceError> {
let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"MissingParameter",
"The request must contain the parameter RoleArn",
)
})?;
validate_string_length("roleArn", role_arn, 20, 2048)?;
let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"MissingParameter",
"The request must contain the parameter RoleSessionName",
)
})?;
validate_string_length("roleSessionName", role_session_name, 2, 64)?;
let web_identity_token = req.query_params.get("WebIdentityToken").ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"MissingParameter",
"The request must contain the parameter WebIdentityToken",
)
})?;
validate_string_length("webIdentityToken", web_identity_token, 4, 20000)?;
let web_identity_token_owned = web_identity_token.clone();
validate_optional_string_length(
"policy",
req.query_params.get("Policy").map(|s| s.as_str()),
1,
2048,
)?;
validate_optional_string_length(
"providerId",
req.query_params.get("ProviderId").map(|s| s.as_str()),
4,
2048,
)?;
if let Some(ds) = req.query_params.get("DurationSeconds") {
let v = ds.parse::<i64>().map_err(|_| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"ValidationError",
format!(
"Value '{}' at 'durationSeconds' failed to satisfy constraint: \
Member must be a valid integer",
ds
),
)
})?;
validate_range_i64("durationSeconds", v, 900, 43200)?;
}
let expiration_at = compute_expiration_at(req, DEFAULT_ASSUME_ROLE_DURATION)?;
let expiration = format_expiration(expiration_at);
let partition = partition_for_region(&req.region);
let creds = StsCredentials::generate();
let role_id = xml_responses::generate_role_id();
let mut accounts = self.state.write();
let caller_state = accounts.get_or_create(&req.account_id);
let session_policies = collect_session_policies(req, caller_state);
let account_id =
extract_account_from_arn(role_arn).unwrap_or_else(|| req.account_id.clone());
let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
let assumed_role_arn = format!(
"arn:{}:sts::{}:assumed-role/{}/{}",
partition, account_id, role_name, role_session_name
);
let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
let jwt = decode_jwt(&web_identity_token_owned);
let provider_id_param = req.query_params.get("ProviderId").cloned();
let oidc_match = jwt.as_ref().and_then(|c| c.iss.as_deref()).and_then(|iss| {
find_oidc_provider(&accounts, iss).map(|(_, p)| (iss.to_string(), p.clone()))
});
if let Some(ref claims) = jwt {
if let Some(ref iss) = claims.iss {
if oidc_match.is_none() {
return Err(AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"InvalidIdentityToken",
format!("No OpenIDConnect provider found in your account for issuer {iss}"),
));
}
if let Some((ref _iss, ref provider)) = oidc_match {
if !provider.client_id_list.is_empty() {
let any_match = claims
.aud
.iter()
.any(|aud| provider.client_id_list.iter().any(|c| c == aud));
if !any_match {
return Err(AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"InvalidIdentityToken",
format!(
"Incorrect token audience: not in client_id_list for provider {}",
provider.arn
),
));
}
}
}
}
}
let federated_provider = oidc_match
.as_ref()
.map(|(_iss, p)| p.arn.clone())
.or(provider_id_param.clone())
.unwrap_or_else(|| format!("arn:aws:iam::{}:oidc-provider/web-identity", account_id));
let target_state = accounts.get_or_create(&account_id);
if let Some(role) = target_state.roles.get(role_name).cloned() {
let trust_doc = PolicyDocument::parse(&role.assume_role_policy_document);
let caller_principal = federated_principal(&federated_provider, &account_id);
let mut context = RequestContext {
aws_principal_arn: Some(caller_principal.arn.clone()),
aws_principal_account: Some(caller_principal.account_id.clone()),
aws_principal_type: Some(caller_principal.principal_type.as_str().to_string()),
aws_federated_provider: Some(federated_provider.clone()),
..Default::default()
};
context.service_keys.insert(
"sts:rolesessionname".to_string(),
vec![role_session_name.clone()],
);
let key_prefix = oidc_match
.as_ref()
.map(|(_iss, p)| normalize_issuer(&p.url))
.or_else(|| provider_id_param.as_deref().map(normalize_issuer));
if let Some(prefix) = key_prefix {
if let Some(ref claims) = jwt {
if !claims.aud.is_empty() {
context
.service_keys
.insert(format!("{prefix}:aud"), claims.aud.clone());
}
if let Some(ref sub) = claims.sub {
context
.service_keys
.insert(format!("{prefix}:sub"), vec![sub.clone()]);
context.aws_userid = Some(sub.clone());
}
if let Some(amr) = claims.raw.get("amr").and_then(|v| v.as_array()).map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect::<Vec<_>>()
}) {
context.service_keys.insert(format!("{prefix}:amr"), amr);
}
}
}
let eval_req = EvalRequest {
principal: &caller_principal,
action: "sts:AssumeRoleWithWebIdentity".to_string(),
resource: role_arn.clone(),
context,
};
if !matches!(
evaluate_resource_policy_only(&trust_doc, &eval_req),
Decision::Allow
) {
return Err(trust_policy_denied(
"sts:AssumeRoleWithWebIdentity",
&caller_principal.arn,
role_arn,
));
}
}
let target_state = accounts.get_or_create(&account_id);
target_state.credential_identities.insert(
creds.access_key_id.clone(),
CredentialIdentity {
arn: assumed_role_arn.clone(),
user_id: assumed_role_id_str.clone(),
account_id: account_id.clone(),
},
);
let federated_provider = Some(federated_provider);
target_state.sts_temp_credentials.insert(
creds.access_key_id.clone(),
StsTempCredential {
access_key_id: creds.access_key_id.clone(),
secret_access_key: creds.secret_access_key.clone(),
session_token: creds.session_token.clone(),
principal_arn: assumed_role_arn,
user_id: assumed_role_id_str,
account_id: account_id.clone(),
expiration: expiration_at,
session_policies,
mfa_present: false,
issued_at: Utc::now(),
federated_provider,
},
);
let xml = xml_responses::assume_role_with_web_identity_response(
&xml_responses::AssumedRoleInfo {
role_arn,
role_session_name,
assumed_role_id: &role_id,
account_id: &account_id,
partition,
creds: &creds,
expiration: &expiration,
request_id: &req.request_id,
},
);
Ok(AwsResponse::xml(StatusCode::OK, xml))
}
/// Handles `AssumeRoleWithSAML`.
///
/// The session name comes from the SAML assertion (defaulting to
/// `saml-session`). If `PrincipalArn` names a known SAML provider whose
/// metadata declares an audience, an assertion audience that differs from it
/// is rejected.
///
/// When the target role exists, its trust policy is evaluated against a
/// federated principal for the provider, with `saml:aud` and `saml:iss`
/// condition keys taken from the assertion.
///
/// On success, temporary credentials are recorded in the role's account.
/// AssumedRoleId uses the role's real id when the role exists.
///
/// # Errors
/// - `MissingParameter` for a missing `RoleArn`, `PrincipalArn` or
///   `SAMLAssertion`.
/// - `ValidationError` for parameter constraint violations.
/// - `InvalidIdentityToken` on an audience mismatch.
/// - A trust-policy denial when the trust policy does not allow the caller.
fn assume_role_with_saml(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
    let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
        AwsServiceError::aws_error(
            StatusCode::BAD_REQUEST,
            "MissingParameter",
            "The request must contain the parameter RoleArn",
        )
    })?;
    validate_string_length("roleArn", role_arn, 20, 2048)?;
    let principal_arn = req.query_params.get("PrincipalArn").ok_or_else(|| {
        AwsServiceError::aws_error(
            StatusCode::BAD_REQUEST,
            "MissingParameter",
            "The request must contain the parameter PrincipalArn",
        )
    })?;
    validate_string_length("principalArn", principal_arn, 20, 2048)?;
    let saml_provider_arn = principal_arn.clone();
    let saml_assertion = req.query_params.get("SAMLAssertion").ok_or_else(|| {
        AwsServiceError::aws_error(
            StatusCode::BAD_REQUEST,
            "MissingParameter",
            "The request must contain the parameter SAMLAssertion",
        )
    })?;
    validate_string_length("sAMLAssertion", saml_assertion, 4, 100000)?;
    validate_optional_string_length(
        "policy",
        req.query_params.get("Policy").map(|s| s.as_str()),
        1,
        2048,
    )?;
    if let Some(ds) = req.query_params.get("DurationSeconds") {
        let v = ds.parse::<i64>().map_err(|_| {
            AwsServiceError::aws_error(
                StatusCode::BAD_REQUEST,
                "ValidationError",
                format!(
                    "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
                     Member must be a valid integer",
                    ds
                ),
            )
        })?;
        validate_range_i64("durationSeconds", v, 900, 43200)?;
    }
    let expiration_at = compute_expiration_at(req, DEFAULT_ASSUME_ROLE_DURATION)?;
    let expiration = format_expiration(expiration_at);
    let role_session_name =
        extract_saml_session_name(saml_assertion).unwrap_or_else(|| "saml-session".to_string());
    let saml_claims = extract_saml_claims(saml_assertion);
    let partition = partition_for_region(&req.region);
    let creds = StsCredentials::generate();

    let mut accounts = self.state.write();
    let caller_state = accounts.get_or_create(&req.account_id);
    let session_policies = collect_session_policies(req, caller_state);
    let account_id =
        extract_account_from_arn(role_arn).unwrap_or_else(|| req.account_id.clone());
    let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
    let assumed_role_arn = format!(
        "arn:{}:sts::{}:assumed-role/{}/{}",
        partition, account_id, role_name, &role_session_name
    );
    if let Some(provider) = find_saml_provider(&accounts, &saml_provider_arn) {
        if let Some(expected_aud) = expected_saml_audience(&provider.saml_metadata_document) {
            if let Some(ref got) = saml_claims.audience {
                if got != &expected_aud {
                    return Err(AwsServiceError::aws_error(
                        StatusCode::BAD_REQUEST,
                        "InvalidIdentityToken",
                        format!(
                            "SAML assertion audience '{got}' does not match SAML provider '{}'",
                            provider.arn
                        ),
                    ));
                }
            }
        }
    }
    let target_state = accounts.get_or_create(&account_id);
    // Owned copy so the trust check does not hold a borrow of `target_state`.
    let role = target_state.roles.get(role_name).cloned();
    if let Some(role) = role.as_ref() {
        let trust_doc = PolicyDocument::parse(&role.assume_role_policy_document);
        let caller_principal = federated_principal(&saml_provider_arn, &account_id);
        let mut context = RequestContext {
            aws_principal_arn: Some(caller_principal.arn.clone()),
            aws_principal_account: Some(caller_principal.account_id.clone()),
            aws_principal_type: Some(caller_principal.principal_type.as_str().to_string()),
            aws_federated_provider: Some(saml_provider_arn.clone()),
            ..Default::default()
        };
        if let Some(ref aud) = saml_claims.audience {
            context
                .service_keys
                .insert("saml:aud".to_string(), vec![aud.clone()]);
        }
        if let Some(ref iss) = saml_claims.issuer {
            context
                .service_keys
                .insert("saml:iss".to_string(), vec![iss.clone()]);
        }
        context.service_keys.insert(
            "sts:rolesessionname".to_string(),
            vec![role_session_name.clone()],
        );
        let eval_req = EvalRequest {
            principal: &caller_principal,
            action: "sts:AssumeRoleWithSAML".to_string(),
            resource: role_arn.clone(),
            context,
        };
        if !matches!(
            evaluate_resource_policy_only(&trust_doc, &eval_req),
            Decision::Allow
        ) {
            return Err(trust_policy_denied(
                "sts:AssumeRoleWithSAML",
                &caller_principal.arn,
                role_arn,
            ));
        }
    }
    // Match `assume_role`: an existing role contributes its real id to
    // AssumedRoleId; a random id is used only when the role is unknown.
    let role_id = role
        .as_ref()
        .map(|r| r.role_id.clone())
        .unwrap_or_else(xml_responses::generate_role_id);
    let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
    target_state.credential_identities.insert(
        creds.access_key_id.clone(),
        CredentialIdentity {
            arn: assumed_role_arn.clone(),
            user_id: assumed_role_id_str.clone(),
            account_id: account_id.clone(),
        },
    );
    target_state.sts_temp_credentials.insert(
        creds.access_key_id.clone(),
        StsTempCredential {
            access_key_id: creds.access_key_id.clone(),
            secret_access_key: creds.secret_access_key.clone(),
            session_token: creds.session_token.clone(),
            principal_arn: assumed_role_arn,
            user_id: assumed_role_id_str,
            account_id: account_id.clone(),
            expiration: expiration_at,
            session_policies,
            mfa_present: false,
            issued_at: Utc::now(),
            federated_provider: Some(saml_provider_arn),
        },
    );
    let xml = xml_responses::assume_role_with_saml_response(&xml_responses::AssumedRoleInfo {
        role_arn,
        role_session_name: &role_session_name,
        assumed_role_id: &role_id,
        account_id: &account_id,
        partition,
        creds: &creds,
        expiration: &expiration,
        request_id: &req.request_id,
    });
    Ok(AwsResponse::xml(StatusCode::OK, xml))
}
/// Handles `GetSessionToken`.
///
/// Issues temporary credentials that carry the identity of the SigV4 access
/// key that signed the request. The root of the caller's account is used
/// when the request is unsigned or the key is unknown. The credentials are
/// recorded in the caller's account state. MFA is marked present when both
/// `SerialNumber` and `TokenCode` are supplied; the token itself is not
/// checked.
///
/// # Errors
/// `ValidationError` for an out-of-range or non-integer `DurationSeconds`, or
/// for malformed MFA parameters.
fn get_session_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
    if let Some(ds) = req.query_params.get("DurationSeconds") {
        let v = ds.parse::<i64>().map_err(|_| {
            AwsServiceError::aws_error(
                StatusCode::BAD_REQUEST,
                "ValidationError",
                format!(
                    "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
                     Member must be a valid integer",
                    ds
                ),
            )
        })?;
        validate_range_i64("durationSeconds", v, 900, 129600)?;
    }
    validate_optional_string_length(
        "serialNumber",
        req.query_params.get("SerialNumber").map(|s| s.as_str()),
        9,
        256,
    )?;
    validate_optional_string_length(
        "tokenCode",
        req.query_params.get("TokenCode").map(|s| s.as_str()),
        6,
        6,
    )?;
    let expiration_at = compute_expiration_at(req, DEFAULT_SESSION_TOKEN_DURATION)?;
    let expiration = format_expiration(expiration_at);
    let mfa_present = req.query_params.contains_key("SerialNumber")
        && req.query_params.contains_key("TokenCode");
    let partition = partition_for_region(&req.region);
    // Parsing the header needs no IAM state, so do it before taking the lock.
    let signing_key = extract_access_key(req);

    let mut accounts = self.state.write();
    let state = accounts.get_or_create(&req.account_id);
    let lookup = match signing_key.as_deref() {
        Some(akid) => state.credential_secret_readonly(akid),
        None => None,
    };
    // Unsigned requests and unknown keys both fall back to the account root.
    let (principal_arn, user_id, account_id) = match lookup {
        Some(found) => (found.principal_arn, found.user_id, found.account_id),
        None => (
            format!("arn:{}:iam::{}:root", partition, state.account_id),
            state.account_id.clone(),
            state.account_id.clone(),
        ),
    };
    let creds = StsCredentials::generate();
    state.credential_identities.insert(
        creds.access_key_id.clone(),
        CredentialIdentity {
            arn: principal_arn.clone(),
            user_id: user_id.clone(),
            account_id: account_id.clone(),
        },
    );
    state.sts_temp_credentials.insert(
        creds.access_key_id.clone(),
        StsTempCredential {
            access_key_id: creds.access_key_id.clone(),
            secret_access_key: creds.secret_access_key.clone(),
            session_token: creds.session_token.clone(),
            principal_arn,
            user_id,
            account_id,
            expiration: expiration_at,
            session_policies: Vec::new(),
            mfa_present,
            issued_at: Utc::now(),
            federated_provider: None,
        },
    );
    let xml = xml_responses::get_session_token_response(&creds, &expiration, &req.request_id);
    Ok(AwsResponse::xml(StatusCode::OK, xml))
}
/// Handles `GetFederationToken`.
///
/// Issues temporary credentials for a `federated-user/<Name>` identity in the
/// caller's account. The credentials carry any session policies supplied
/// with the request, and the inline `Policy` is echoed into the response.
///
/// # Errors
/// - `MissingParameter` when `Name` is absent.
/// - `ValidationError` for parameter constraint violations.
fn get_federation_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
    let params = &req.query_params;
    let name = match params.get("Name") {
        Some(found) => found,
        None => {
            return Err(AwsServiceError::aws_error(
                StatusCode::BAD_REQUEST,
                "MissingParameter",
                "The request must contain the parameter Name",
            ))
        }
    };
    validate_string_length("name", name, 2, 32)?;
    if let Some(raw) = params.get("DurationSeconds") {
        let seconds = raw.parse::<i64>().map_err(|_| {
            AwsServiceError::aws_error(
                StatusCode::BAD_REQUEST,
                "ValidationError",
                format!(
                    "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
                     Member must be a valid integer",
                    raw
                ),
            )
        })?;
        validate_range_i64("durationSeconds", seconds, 900, 129600)?;
    }
    let inline_policy = params.get("Policy").map(String::as_str);
    validate_optional_string_length("policy", inline_policy, 1, 2048)?;
    let expiration_at = compute_expiration_at(req, DEFAULT_FEDERATION_TOKEN_DURATION)?;
    let expiration = format_expiration(expiration_at);
    let partition = partition_for_region(&req.region);
    let creds = StsCredentials::generate();

    let mut accounts = self.state.write();
    let state = accounts.get_or_create(&req.account_id);
    let session_policies = collect_session_policies(req, state);
    let account_id = state.account_id.clone();
    let user_arn = format!(
        "arn:{}:sts::{}:federated-user/{}",
        partition, account_id, name
    );
    let user_id = format!("{}:{}", account_id, name);
    let identity = CredentialIdentity {
        arn: user_arn.clone(),
        user_id: user_id.clone(),
        account_id: account_id.clone(),
    };
    state
        .credential_identities
        .insert(creds.access_key_id.clone(), identity);
    let temp_credential = StsTempCredential {
        access_key_id: creds.access_key_id.clone(),
        secret_access_key: creds.secret_access_key.clone(),
        session_token: creds.session_token.clone(),
        principal_arn: user_arn,
        user_id,
        account_id: account_id.clone(),
        expiration: expiration_at,
        session_policies,
        mfa_present: false,
        issued_at: Utc::now(),
        federated_provider: None,
    };
    state
        .sts_temp_credentials
        .insert(creds.access_key_id.clone(), temp_credential);
    let xml = xml_responses::get_federation_token_response(
        &creds,
        name,
        &account_id,
        partition,
        &expiration,
        inline_policy,
        &req.request_id,
    );
    Ok(AwsResponse::xml(StatusCode::OK, xml))
}
/// Handles `DecodeAuthorizationMessage`: decodes `EncodedMessage` with
/// `crate::auth_message::decode_message` and returns the decoded text.
///
/// # Errors
/// - `MissingParameter` when `EncodedMessage` is absent.
/// - `ValidationError` when its length is outside 1..=10240.
/// - `InvalidAuthorizationMessageException` when decoding fails.
fn decode_authorization_message(
    &self,
    req: &AwsRequest,
) -> Result<AwsResponse, AwsServiceError> {
    let encoded_message = match req.query_params.get("EncodedMessage") {
        Some(message) => message,
        None => {
            return Err(AwsServiceError::aws_error(
                StatusCode::BAD_REQUEST,
                "MissingParameter",
                "The request must contain the parameter EncodedMessage",
            ))
        }
    };
    validate_string_length("encodedMessage", encoded_message, 1, 10240)?;
    let decoded = match crate::auth_message::decode_message(encoded_message) {
        Ok(text) => text,
        Err(why) => {
            return Err(AwsServiceError::aws_error(
                StatusCode::BAD_REQUEST,
                "InvalidAuthorizationMessageException",
                why,
            ))
        }
    };
    let body = xml_responses::decode_authorization_message_response(&decoded, &req.request_id);
    Ok(AwsResponse::xml(StatusCode::OK, body))
}
/// Handles `GetAccessKeyInfo`: reports the account that owns `AccessKeyId`.
///
/// Accounts are scanned in iteration order. Within each account, long-term
/// IAM access keys are checked first, then issued temporary credential
/// identities. An unknown key reports the default account.
///
/// # Errors
/// - `MissingParameter` when `AccessKeyId` is absent.
/// - `ValidationError` when its length is outside 16..=128.
fn get_access_key_info(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
    let access_key_id = req.query_params.get("AccessKeyId").ok_or_else(|| {
        AwsServiceError::aws_error(
            StatusCode::BAD_REQUEST,
            "MissingParameter",
            "The request must contain the parameter AccessKeyId",
        )
    })?;
    validate_string_length("accessKeyId", access_key_id, 16, 128)?;
    let accounts = self.state.read();
    let owner = accounts.iter().find_map(|(acct_id, acct_state)| {
        let owns_long_term_key = acct_state
            .access_keys
            .values()
            .flatten()
            .any(|k| k.access_key_id == *access_key_id);
        if owns_long_term_key {
            return Some(acct_id.to_string());
        }
        acct_state
            .credential_identities
            .get(access_key_id.as_str())
            .map(|ci| ci.account_id.clone())
    });
    let account_id = owner.unwrap_or_else(|| accounts.default_account_id().to_string());
    let xml = xml_responses::get_access_key_info_response(&account_id, &req.request_id);
    Ok(AwsResponse::xml(StatusCode::OK, xml))
}
/// Handles `AssumeRoot`: issues short-lived credentials acting as the root of
/// the account named by `TargetPrincipal`.
///
/// `TargetPrincipal` may be an ARN (its account component is used) or a bare
/// 12-digit account id (expanded to that account's root ARN). The
/// credentials are recorded in the target account with no session policies.
///
/// # Errors
/// - `MissingParameter` when `TargetPrincipal` or `TaskPolicyArn` is absent.
/// - `ValidationError` for a malformed target, a bad `TaskPolicyArn` length,
///   or an out-of-range `DurationSeconds`.
fn assume_root(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
    // Upper bound for DurationSeconds, and the lifetime used when it is omitted.
    const ROOT_SESSION_SECONDS: i64 = 900;
    let params = &req.query_params;
    let target_principal = params.get("TargetPrincipal").ok_or_else(|| {
        AwsServiceError::aws_error(
            StatusCode::BAD_REQUEST,
            "MissingParameter",
            "The request must contain the parameter TargetPrincipal",
        )
    })?;
    // Accept both the nested `TaskPolicyArn.arn` form and a flat `TaskPolicyArn`.
    let task_policy_arn = params
        .get("TaskPolicyArn.arn")
        .or_else(|| params.get("TaskPolicyArn"))
        .ok_or_else(|| {
            AwsServiceError::aws_error(
                StatusCode::BAD_REQUEST,
                "MissingParameter",
                "The request must contain the parameter TaskPolicyArn",
            )
        })?;
    validate_string_length("taskPolicyArn", task_policy_arn, 20, 2048)?;
    if let Some(raw) = params.get("DurationSeconds") {
        let seconds = raw.parse::<i64>().map_err(|_| {
            AwsServiceError::aws_error(
                StatusCode::BAD_REQUEST,
                "ValidationError",
                format!(
                    "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
                     Member must be a valid integer",
                    raw
                ),
            )
        })?;
        validate_range_i64("durationSeconds", seconds, 0, ROOT_SESSION_SECONDS)?;
    }
    let partition = partition_for_region(&req.region);
    let is_bare_account_id =
        target_principal.len() == 12 && target_principal.chars().all(|c| c.is_ascii_digit());
    let (target_account, target_arn) = if target_principal.starts_with("arn:") {
        match extract_account_from_arn(target_principal) {
            Some(acct) => (acct, target_principal.to_string()),
            None => {
                return Err(AwsServiceError::aws_error(
                    StatusCode::BAD_REQUEST,
                    "ValidationError",
                    "TargetPrincipal ARN is malformed",
                ))
            }
        }
    } else if is_bare_account_id {
        let root_arn = format!("arn:{}:iam::{}:root", partition, target_principal);
        (target_principal.to_string(), root_arn)
    } else {
        return Err(AwsServiceError::aws_error(
            StatusCode::BAD_REQUEST,
            "ValidationError",
            "TargetPrincipal must be a member account ID or root ARN",
        ));
    };
    let expiration_at = compute_expiration_at(req, ROOT_SESSION_SECONDS)?;
    let expiration = format_expiration(expiration_at);
    let creds = StsCredentials::generate();

    let mut accounts = self.state.write();
    let state = accounts.get_or_create(&target_account);
    let identity = CredentialIdentity {
        arn: target_arn.clone(),
        user_id: target_account.clone(),
        account_id: target_account.clone(),
    };
    state
        .credential_identities
        .insert(creds.access_key_id.clone(), identity);
    let temp_credential = StsTempCredential {
        access_key_id: creds.access_key_id.clone(),
        secret_access_key: creds.secret_access_key.clone(),
        session_token: creds.session_token.clone(),
        principal_arn: target_arn.clone(),
        user_id: target_account.clone(),
        account_id: target_account.clone(),
        expiration: expiration_at,
        session_policies: Vec::new(),
        mfa_present: false,
        issued_at: Utc::now(),
        federated_provider: None,
    };
    state
        .sts_temp_credentials
        .insert(creds.access_key_id.clone(), temp_credential);
    let source_identity = params.get("SourceIdentity").map(|s| s.as_str());
    let xml = xml_responses::assume_root_response(
        &creds,
        &expiration,
        source_identity,
        &req.request_id,
    );
    Ok(AwsResponse::xml(StatusCode::OK, xml))
}
}
/// Returns the account-ID segment of an ARN.
///
/// ARNs are `arn:partition:service:region:account:resource`; the account is the fifth
/// colon-separated field. Returns `None` when there are fewer than five fields or when that
/// field is empty (as in global-resource ARNs such as `arn:aws:s3:::bucket`).
fn extract_account_from_arn(arn: &str) -> Option<String> {
    const ACCOUNT_FIELD: usize = 4;
    arn.split(':')
        .nth(ACCOUNT_FIELD)
        .filter(|account| !account.is_empty())
        .map(str::to_owned)
}
/// Claims pulled out of a base64-encoded SAML assertion by `extract_saml_claims`.
///
/// Both fields are `None` when the assertion is not valid base64 / UTF-8, or when no matching
/// element with non-empty text is present.
#[derive(Debug, Clone, Default)]
struct SamlClaims {
/// Text of the first `Issuer` element found by `extract_xml_text_after` (namespace prefix ignored).
issuer: Option<String>,
/// Text of the first `Audience` element found by `extract_xml_text_after` (namespace prefix ignored).
audience: Option<String>,
}
fn extract_saml_claims(saml_b64: &str) -> SamlClaims {
use base64::Engine;
let mut claims = SamlClaims::default();
let decoded = match base64::engine::general_purpose::STANDARD.decode(saml_b64) {
Ok(b) => b,
Err(_) => return claims,
};
let xml_str = match String::from_utf8(decoded) {
Ok(s) => s,
Err(_) => return claims,
};
claims.issuer = extract_xml_text_after(&xml_str, "Issuer");
claims.audience = extract_xml_text_after(&xml_str, "Audience");
claims
}
/// Returns the trimmed text content of the first element whose local name is `local_name`.
///
/// This is a small, non-validating scanner for decoded SAML assertions, not a real XML parser.
/// It works as follows:
/// - For every `<` in `xml`, read the tag name up to the first `>`, `/`, or XML whitespace.
/// - Drop an optional namespace prefix (`saml:Issuer` -> `Issuer`) and compare the rest with
///   `local_name`.
/// - On a match, return the text between the start tag's `>` and the next `<`, trimmed.
///
/// Closing tags, self-closing tags, and matches whose text is empty or whitespace-only are
/// skipped, and the scan continues. Entity references and CDATA sections are not decoded.
///
/// Returns `None` if no matching element with non-empty text is found.
fn extract_xml_text_after(xml: &str, local_name: &str) -> Option<String> {
    let mut search_from = 0;
    while let Some(idx) = xml[search_from..].find('<') {
        let abs = search_from + idx;
        let after_lt = &xml[abs + 1..];
        // Restrict the prefix split to the tag name itself. A ':' further along (the scheme
        // in `https://...` element text, or a `xmlns="urn:..."` attribute) is not part of the
        // name, and splitting there would hide unprefixed elements such as `<Issuer>`.
        let name_end = after_lt
            .find(|c: char| matches!(c, '>' | '/' | ' ' | '\t' | '\n' | '\r'))
            .unwrap_or(after_lt.len());
        let qualified_name = &after_lt[..name_end];
        let local = qualified_name
            .split_once(':')
            .map_or(qualified_name, |(_prefix, rest)| rest);
        if local == local_name {
            let gt_pos = after_lt.find('>')?;
            // `<Issuer/>` has no content; don't return whatever text happens to follow it.
            let self_closing = after_lt[..gt_pos].ends_with('/');
            if !self_closing {
                let content_start = abs + 1 + gt_pos + 1;
                let next_lt = xml[content_start..].find('<')?;
                let value = xml[content_start..content_start + next_lt].trim();
                if !value.is_empty() {
                    return Some(value.to_string());
                }
            }
        }
        search_from = abs + 1;
    }
    None
}
/// Payload claims decoded by `decode_jwt`. The token's signature is never checked.
#[derive(Debug, Clone, Default)]
struct JwtClaims {
/// The `iss` claim, when present and a JSON string.
iss: Option<String>,
/// The `aud` claim as a list: a single string becomes one element, non-string array entries
/// are dropped, and a missing claim or any other JSON type yields an empty list.
aud: Vec<String>,
/// The `sub` claim, when present and a JSON string.
sub: Option<String>,
/// The complete decoded payload object, for claims not surfaced as dedicated fields.
raw: serde_json::Map<String, serde_json::Value>,
}
/// Decodes the payload claims of a compact JWT (`header.payload.signature`).
///
/// Only the middle segment is examined. The header and signature are ignored, so the token's
/// authenticity is **not** verified. The payload is decoded as unpadded base64url, falling back
/// to padded base64url, and must be a JSON object.
///
/// Returns `None` in any of these cases:
/// - the token does not have exactly three `.`-separated segments;
/// - the payload is not valid base64url;
/// - the payload does not decode to a JSON object.
fn decode_jwt(token: &str) -> Option<JwtClaims> {
    use base64::Engine;
    let payload_b64 = match token.split('.').collect::<Vec<_>>().as_slice() {
        [_header, payload, _signature] => *payload,
        _ => return None,
    };
    let payload_bytes = match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64)
    {
        Ok(bytes) => bytes,
        Err(_) => base64::engine::general_purpose::URL_SAFE
            .decode(payload_b64)
            .ok()?,
    };
    // Take ownership of the object directly rather than borrowing and cloning it.
    let claims = match serde_json::from_slice::<serde_json::Value>(&payload_bytes).ok()? {
        serde_json::Value::Object(map) => map,
        _ => return None,
    };

    let string_claim = |key: &str| {
        claims
            .get(key)
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned)
    };
    let iss = string_claim("iss");
    let sub = string_claim("sub");
    let aud = match claims.get("aud") {
        Some(serde_json::Value::String(single)) => vec![single.clone()],
        Some(serde_json::Value::Array(many)) => many
            .iter()
            .filter_map(serde_json::Value::as_str)
            .map(str::to_owned)
            .collect(),
        _ => Vec::new(),
    };

    Some(JwtClaims {
        iss,
        aud,
        sub,
        raw: claims,
    })
}
/// Canonicalises an OIDC issuer / provider URL for comparison.
///
/// Strips one leading `https://` (or, failing that, `http://`) and every trailing `/`.
/// Anything else, including case, is left untouched.
fn normalize_issuer(value: &str) -> String {
    const SCHEMES: [&str; 2] = ["https://", "http://"];
    let without_scheme = SCHEMES
        .iter()
        .find_map(|scheme| value.strip_prefix(*scheme))
        .unwrap_or(value);
    without_scheme.trim_end_matches('/').to_owned()
}
/// Finds an OIDC provider whose URL denotes the same issuer as `issuer`.
///
/// Both sides are compared after `normalize_issuer`, which removes the scheme and trailing
/// slashes. Accounts are visited in `accounts.iter()` order and the first match wins. The
/// match is returned together with the ID of the account that owns it.
fn find_oidc_provider<'a>(
    accounts: &'a fakecloud_core::multi_account::MultiAccountState<IamState>,
    issuer: &str,
) -> Option<(&'a str, &'a crate::state::OidcProvider)> {
    let wanted = normalize_issuer(issuer);
    for (acct_id, state) in accounts.iter() {
        let hit = state
            .oidc_providers
            .values()
            .find(|provider| normalize_issuer(&provider.url) == wanted);
        if let Some(provider) = hit {
            return Some((acct_id, provider));
        }
    }
    None
}
/// Looks up a SAML provider by its full ARN in every account of `accounts`.
///
/// Each account's `saml_providers` map is consulted with `arn` as the key, in
/// `accounts.iter()` order. Returns the first hit, or `None` if no account has that key.
fn find_saml_provider<'a>(
    accounts: &'a fakecloud_core::multi_account::MultiAccountState<IamState>,
    arn: &str,
) -> Option<&'a crate::state::SamlProvider> {
    accounts
        .iter()
        .find_map(|(_account_id, account)| account.saml_providers.get(arn))
}
/// Extracts the value of the first `entityID=` attribute in SAML metadata XML.
///
/// The value must follow `=` immediately and be wrapped in single or double quotes. It is
/// returned trimmed. Returns `None` in any of these cases:
/// - the attribute is absent;
/// - the value is unquoted or its closing quote is missing;
/// - the value is empty after trimming.
fn expected_saml_audience(metadata: &str) -> Option<String> {
    const ENTITY_ID_ATTR: &str = "entityID=";
    let (_, tail) = metadata.split_once(ENTITY_ID_ATTR)?;
    let mut chars = tail.chars();
    let quote = chars.next().filter(|c| *c == '"' || *c == '\'')?;
    let quoted = chars.as_str();
    let value = quoted[..quoted.find(quote)?].trim();
    (!value.is_empty()).then(|| value.to_string())
}
/// Builds a `FederatedUser` principal for an identity-provider ARN.
///
/// `provider_arn` is used as both the principal's ARN and its user ID. The principal belongs to
/// `account_id` and carries no source identity or tags.
fn federated_principal(provider_arn: &str, account_id: &str) -> Principal {
    let arn = provider_arn.to_owned();
    Principal {
        user_id: arn.clone(),
        arn,
        account_id: account_id.to_owned(),
        principal_type: PrincipalType::FederatedUser,
        source_identity: None,
        tags: None,
    }
}
/// Builds a 403 `AccessDenied` error stating that `caller_arn` may not perform `action` on
/// `role_arn`.
fn trust_policy_denied(action: &str, caller_arn: &str, role_arn: &str) -> AwsServiceError {
    let message = format!(
        "User: {caller_arn} is not authorized to perform: {action} on resource: {role_arn}"
    );
    AwsServiceError::aws_error(StatusCode::FORBIDDEN, "AccessDenied", message)
}
/// Pulls the `RoleSessionName` attribute value out of a base64-encoded SAML assertion.
///
/// The function works as follows:
/// - Decode the assertion (standard base64 alphabet) and require UTF-8.
/// - Locate the AWS `RoleSessionName` attribute name.
/// - Take the first `AttributeValue` element that follows it and return that element's text,
///   trimmed.
///
/// Note that this is a plain text scan: if that attribute has no value of its own, the next
/// `AttributeValue` in the document is used. Returns `None` in any of these cases:
/// - the input cannot be decoded;
/// - the attribute or a following value is missing;
/// - the value is empty.
fn extract_saml_session_name(saml_b64: &str) -> Option<String> {
    use base64::Engine;
    const ROLE_SESSION_NAME_ATTR: &str =
        "https://aws.amazon.com/SAML/Attributes/RoleSessionName";
    let bytes = base64::engine::general_purpose::STANDARD
        .decode(saml_b64)
        .ok()?;
    let xml = String::from_utf8(bytes).ok()?;
    let (_, from_attr) = xml.split_once(ROLE_SESSION_NAME_ATTR)?;
    let (_, from_value_tag) = from_attr.split_once("AttributeValue")?;
    let (_, text_onward) = from_value_tag.split_once('>')?;
    let (text, _) = text_onward.split_once('<')?;
    let session_name = text.trim();
    if session_name.is_empty() {
        None
    } else {
        Some(session_name.to_owned())
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_partition_for_region() {
assert_eq!(partition_for_region("us-east-1"), "aws");
assert_eq!(partition_for_region("eu-west-1"), "aws");
assert_eq!(partition_for_region("cn-north-1"), "aws-cn");
assert_eq!(partition_for_region("cn-northwest-1"), "aws-cn");
assert_eq!(partition_for_region("us-isob-east-1"), "aws-iso-b");
assert_eq!(partition_for_region("us-iso-east-1"), "aws-iso");
}
#[test]
fn test_extract_account_from_arn() {
assert_eq!(
extract_account_from_arn("arn:aws:iam::123456789012:role/test"),
Some("123456789012".to_string())
);
assert_eq!(
extract_account_from_arn("arn:aws:iam::111111111111:role/test"),
Some("111111111111".to_string())
);
assert_eq!(extract_account_from_arn("invalid"), None);
}
#[test]
fn test_extract_saml_session_name() {
use base64::Engine;
let xml = r#"<?xml version="1.0"?><samlp:Response><Assertion><AttributeStatement><Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><AttributeValue>testuser</AttributeValue></Attribute></AttributeStatement></Assertion></samlp:Response>"#;
let encoded = base64::engine::general_purpose::STANDARD.encode(xml.as_bytes());
assert_eq!(
extract_saml_session_name(&encoded),
Some("testuser".to_string())
);
}
#[test]
fn test_extract_saml_session_name_with_namespace() {
use base64::Engine;
let xml = r#"<?xml version="1.0"?><samlp:Response><saml:Assertion><saml:AttributeStatement><saml:Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><saml:AttributeValue>testuser</saml:AttributeValue></saml:Attribute></saml:AttributeStatement></saml:Assertion></samlp:Response>"#;
let encoded = base64::engine::general_purpose::STANDARD.encode(xml.as_bytes());
assert_eq!(
extract_saml_session_name(&encoded),
Some("testuser".to_string())
);
}
#[test]
fn test_session_token_format() {
let token = xml_responses::generate_session_token();
assert_eq!(token.len(), 356);
assert!(token.starts_with("FQoGZXIvYXdzE"));
}
#[test]
fn test_access_key_id_format() {
let key = xml_responses::generate_access_key_id();
assert_eq!(key.len(), 20);
assert!(key.starts_with("FSIA"));
}
#[test]
fn test_secret_access_key_format() {
let key = xml_responses::generate_secret_access_key();
assert_eq!(key.len(), 40);
}
#[test]
fn test_role_id_format() {
let id = xml_responses::generate_role_id();
assert_eq!(id.len(), 21);
assert!(id.starts_with("AROA"));
}
#[test]
fn test_decode_authorization_message() {
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
let state: SharedIamState = Arc::new(RwLock::new(
fakecloud_core::multi_account::MultiAccountState::new("123456789012", "us-east-1", ""),
));
let service = StsService::new(state);
let token = crate::auth_message::encode_deny(
true,
Some("s3:GetObject"),
Some("arn:aws:iam::123456789012:user/alice"),
vec![serde_json::json!({"sourcePolicyId": "deny-bucket-foo"})],
None,
);
let mut params = HashMap::new();
params.insert("EncodedMessage".to_string(), token);
let req = make_test_request(params);
let resp = service.decode_authorization_message(&req).unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
assert!(body.contains("DecodedMessage"));
assert!(body.contains("explicitDeny"));
assert!(body.contains("s3:GetObject"));
assert!(body.contains("deny-bucket-foo"));
}
#[test]
fn test_decode_authorization_message_rejects_invalid_token() {
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
let state: SharedIamState = Arc::new(RwLock::new(
fakecloud_core::multi_account::MultiAccountState::new("123456789012", "us-east-1", ""),
));
let service = StsService::new(state);
let mut params = HashMap::new();
params.insert("EncodedMessage".to_string(), "not-a-real-token".to_string());
let req = make_test_request(params);
let err = match service.decode_authorization_message(&req) {
Err(e) => e,
Ok(_) => panic!("expected InvalidAuthorizationMessageException"),
};
assert_eq!(err.status(), StatusCode::BAD_REQUEST);
let msg = format!("{:?}", err);
assert!(msg.contains("InvalidAuthorizationMessageException"));
}
#[test]
fn test_decode_authorization_message_missing_param() {
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
let state: SharedIamState = Arc::new(RwLock::new(
fakecloud_core::multi_account::MultiAccountState::new("123456789012", "us-east-1", ""),
));
let service = StsService::new(state);
let req = make_test_request(HashMap::new());
let result = service.decode_authorization_message(&req);
assert!(result.is_err());
let err = result.err().unwrap();
let msg = format!("{:?}", err);
assert!(msg.contains("EncodedMessage"));
}
fn make_test_request(params: std::collections::HashMap<String, String>) -> AwsRequest {
AwsRequest {
service: "sts".into(),
action: "Test".into(),
region: "us-east-1".into(),
account_id: "123456789012".into(),
request_id: "test".into(),
headers: http::HeaderMap::new(),
query_params: params,
body: Default::default(),
body_stream: parking_lot::Mutex::new(None),
path_segments: vec![],
raw_path: "/".into(),
raw_query: String::new(),
method: http::Method::POST,
is_query_protocol: true,
access_key_id: None,
principal: None,
}
}
fn parse_expiration(s: &str) -> chrono::DateTime<Utc> {
chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ")
.expect("valid timestamp")
.and_utc()
}
#[test]
fn test_compute_expiration_with_duration() {
use std::collections::HashMap;
let mut params = HashMap::new();
params.insert("DurationSeconds".to_string(), "1800".to_string());
let req = make_test_request(params);
let now = Utc::now();
let exp_str = compute_expiration(&req, 3600).unwrap();
let exp_utc = parse_expiration(&exp_str);
let diff = (exp_utc - now).num_seconds();
assert!(
(1798..=1802).contains(&diff),
"expected ~1800s duration, got {diff}s"
);
}
#[test]
fn test_compute_expiration_default() {
use std::collections::HashMap;
let req = make_test_request(HashMap::new());
let now = Utc::now();
let exp_str = compute_expiration(&req, 43200).unwrap();
let exp_utc = parse_expiration(&exp_str);
let diff = (exp_utc - now).num_seconds();
assert!(
(43198..=43202).contains(&diff),
"expected ~43200s duration, got {diff}s"
);
}
#[test]
fn test_compute_expiration_uses_provided_not_default() {
use std::collections::HashMap;
let mut params = HashMap::new();
params.insert("DurationSeconds".to_string(), "900".to_string());
let req = make_test_request(params);
let before = Utc::now();
let exp_str = compute_expiration(&req, 43200).unwrap();
let exp_utc = parse_expiration(&exp_str);
let expected = before + chrono::Duration::seconds(900);
let diff = (exp_utc - expected).num_seconds().abs();
assert!(
diff <= 2,
"expected ~900s duration, got diff={diff}s from expected"
);
}
fn make_sts_service() -> (StsService, SharedIamState) {
use parking_lot::RwLock;
use std::sync::Arc;
let state: SharedIamState = Arc::new(RwLock::new(
fakecloud_core::multi_account::MultiAccountState::new("123456789012", "us-east-1", ""),
));
let sts = StsService::new(state.clone());
(sts, state)
}
fn sts_request(action: &str, params: Vec<(&str, &str)>) -> AwsRequest {
let mut qp = std::collections::HashMap::new();
qp.insert("Action".to_string(), action.to_string());
for (k, v) in params {
qp.insert(k.to_string(), v.to_string());
}
let mut req = make_test_request(qp);
req.action = action.to_string();
req
}
fn create_role_in_state(state: &SharedIamState, name: &str) -> String {
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole"}]}"#;
create_role_in_state_with_trust(state, name, trust)
}
fn create_role_in_state_with_trust(
state: &SharedIamState,
name: &str,
trust_policy: &str,
) -> String {
let arn = fakecloud_aws::arn::Arn::global("iam", "123456789012", &format!("role/{name}"))
.to_string();
let mut accounts = state.write();
let s = accounts.get_or_create("123456789012");
s.roles.insert(
name.to_string(),
crate::state::IamRole {
role_name: name.to_string(),
role_id: format!("AROA{}", &uuid::Uuid::new_v4().to_string()[..17]),
arn: arn.clone(),
path: "/".to_string(),
assume_role_policy_document: trust_policy.to_string(),
created_at: Utc::now(),
description: None,
max_session_duration: 3600,
tags: Vec::new(),
permissions_boundary: None,
},
);
arn
}
#[tokio::test]
async fn get_caller_identity() {
let (svc, _) = make_sts_service();
let mut req = sts_request("GetCallerIdentity", vec![]);
req.headers.insert(
http::header::AUTHORIZATION,
http::HeaderValue::from_static("AWS4-HMAC-SHA256 Credential=test/test"),
);
let resp = svc.handle(req).await.unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
assert!(body.contains("<Account>123456789012</Account>"));
assert!(body.contains("<Arn>"));
}
#[tokio::test]
async fn get_caller_identity_rejects_unauthenticated_request() {
let (svc, _) = make_sts_service();
let req = sts_request("GetCallerIdentity", vec![]);
let err = match svc.handle(req).await {
Err(e) => e,
Ok(_) => panic!("expected MissingAuthenticationTokenException"),
};
assert_eq!(err.status(), StatusCode::FORBIDDEN);
assert!(format!("{:?}", err).contains("MissingAuthenticationTokenException"));
}
#[tokio::test]
async fn assume_role_basic() {
let (svc, state) = make_sts_service();
let role_arn = create_role_in_state(&state, "test-role");
let req = sts_request(
"AssumeRole",
vec![("RoleArn", &role_arn), ("RoleSessionName", "test-session")],
);
let resp = svc.handle(req).await.unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
assert!(body.contains("<AccessKeyId>"));
assert!(body.contains("<SecretAccessKey>"));
assert!(body.contains("<SessionToken>"));
}
#[tokio::test]
async fn assume_role_not_found() {
let (svc, _) = make_sts_service();
let req = sts_request(
"AssumeRole",
vec![
("RoleArn", "arn:aws:iam::123456789012:role/nonexistent"),
("RoleSessionName", "s"),
],
);
assert!(svc.handle(req).await.is_err());
}
#[tokio::test]
async fn assume_role_missing_session_name() {
let (svc, _) = make_sts_service();
let req = sts_request(
"AssumeRole",
vec![("RoleArn", "arn:aws:iam::123456789012:role/r")],
);
assert!(svc.handle(req).await.is_err());
}
#[tokio::test]
async fn assume_role_with_web_identity() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRoleWithWebIdentity"}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "web-role", trust);
let req = sts_request(
"AssumeRoleWithWebIdentity",
vec![
("RoleArn", &role_arn),
("RoleSessionName", "web-session"),
("WebIdentityToken", "fake-jwt-token"),
],
);
let resp = svc.handle(req).await.unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
assert!(body.contains("<AccessKeyId>"));
}
#[tokio::test]
async fn get_session_token() {
let (svc, _) = make_sts_service();
let req = sts_request("GetSessionToken", vec![]);
let resp = svc.handle(req).await.unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
assert!(body.contains("<AccessKeyId>"));
assert!(body.contains("<SessionToken>"));
}
#[tokio::test]
async fn get_session_token_with_duration() {
let (svc, _) = make_sts_service();
let req = sts_request("GetSessionToken", vec![("DurationSeconds", "1800")]);
let resp = svc.handle(req).await.unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
assert!(body.contains("<Expiration>"));
}
#[tokio::test]
async fn get_federation_token() {
let (svc, _) = make_sts_service();
let req = sts_request("GetFederationToken", vec![("Name", "feduser")]);
let resp = svc.handle(req).await.unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
assert!(body.contains("<AccessKeyId>"));
assert!(body.contains("<FederatedUserId>"));
}
#[tokio::test]
async fn get_access_key_info() {
let (svc, _) = make_sts_service();
let req = sts_request(
"GetAccessKeyInfo",
vec![("AccessKeyId", "AKIAIOSFODNN7EXAMPLE")],
);
let resp = svc.handle(req).await.unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
assert!(body.contains("<Account>"));
}
#[tokio::test]
async fn assume_role_rejects_when_external_id_missing() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole","Condition":{"StringEquals":{"sts:ExternalId":"secret-handshake"}}}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "third-party", trust);
let req = sts_request(
"AssumeRole",
vec![("RoleArn", &role_arn), ("RoleSessionName", "sess")],
);
let err = match svc.handle(req).await {
Err(e) => e,
Ok(_) => panic!("expected AccessDenied when ExternalId missing"),
};
assert_eq!(err.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn assume_role_rejects_when_external_id_mismatches() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole","Condition":{"StringEquals":{"sts:ExternalId":"secret-handshake"}}}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "third-party", trust);
let req = sts_request(
"AssumeRole",
vec![
("RoleArn", &role_arn),
("RoleSessionName", "sess"),
("ExternalId", "wrongguess"),
],
);
let err = match svc.handle(req).await {
Err(e) => e,
Ok(_) => panic!("expected AccessDenied when ExternalId mismatches"),
};
assert_eq!(err.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn assume_role_succeeds_when_external_id_matches() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole","Condition":{"StringEquals":{"sts:ExternalId":"secret-handshake"}}}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "third-party", trust);
let req = sts_request(
"AssumeRole",
vec![
("RoleArn", &role_arn),
("RoleSessionName", "sess"),
("ExternalId", "secret-handshake"),
],
);
let resp = svc.handle(req).await.unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
assert!(body.contains("<AccessKeyId>"));
}
#[tokio::test]
async fn assume_role_proceeds_when_no_external_id_required() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole"}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "open-role", trust);
let req = sts_request(
"AssumeRole",
vec![("RoleArn", &role_arn), ("RoleSessionName", "sess")],
);
svc.handle(req).await.unwrap();
}
#[tokio::test]
async fn assume_role_rejects_when_trust_policy_has_no_statements() {
let (svc, state) = make_sts_service();
let role_arn = create_role_in_state_with_trust(&state, "no-trust", r#"{"Statement":[]}"#);
let req = sts_request(
"AssumeRole",
vec![("RoleArn", &role_arn), ("RoleSessionName", "sess")],
);
let err = match svc.handle(req).await {
Err(e) => e,
Ok(_) => panic!("expected AccessDenied"),
};
assert_eq!(err.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn assume_role_rejects_when_trust_policy_excludes_caller() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"Service":"ec2.amazonaws.com"},"Action":"sts:AssumeRole"}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "ec2-only", trust);
let req = sts_request(
"AssumeRole",
vec![("RoleArn", &role_arn), ("RoleSessionName", "sess")],
);
let err = match svc.handle(req).await {
Err(e) => e,
Ok(_) => panic!("expected AccessDenied"),
};
assert_eq!(err.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn assume_role_rejects_when_trust_policy_explicitly_denies() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole"},{"Effect":"Deny","Principal":{"AWS":"*"},"Action":"sts:AssumeRole"}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "deny-wins", trust);
let req = sts_request(
"AssumeRole",
vec![("RoleArn", &role_arn), ("RoleSessionName", "sess")],
);
let err = match svc.handle(req).await {
Err(e) => e,
Ok(_) => panic!("expected AccessDenied"),
};
assert_eq!(err.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn assume_role_allowed_by_trust_policy_with_principal_match() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"123456789012"},"Action":"sts:AssumeRole"}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "named", trust);
let req = sts_request(
"AssumeRole",
vec![("RoleArn", &role_arn), ("RoleSessionName", "sess")],
);
let resp = svc.handle(req).await.unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
assert!(body.contains("<AccessKeyId>"), "{body}");
}
#[tokio::test]
async fn assume_role_blocked_when_principal_not_in_trust_policy() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"arn:aws:iam::999999999999:root"},"Action":"sts:AssumeRole"}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "other-account", trust);
let req = sts_request(
"AssumeRole",
vec![("RoleArn", &role_arn), ("RoleSessionName", "sess")],
);
let err = match svc.handle(req).await {
Err(e) => e,
Ok(_) => panic!("expected AccessDenied when caller account not in trust policy"),
};
assert_eq!(err.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn assume_role_blocked_when_external_id_required_but_missing() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole","Condition":{"StringEquals":{"sts:ExternalId":"hello"}}}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "ext-required", trust);
let req = sts_request(
"AssumeRole",
vec![("RoleArn", &role_arn), ("RoleSessionName", "sess")],
);
let err = match svc.handle(req).await {
Err(e) => e,
Ok(_) => panic!("expected AccessDenied when ExternalId required but missing"),
};
assert_eq!(err.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn assume_role_succeeds_with_correct_external_id() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole","Condition":{"StringEquals":{"sts:ExternalId":"hello"}}}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "ext-ok", trust);
let req = sts_request(
"AssumeRole",
vec![
("RoleArn", &role_arn),
("RoleSessionName", "sess"),
("ExternalId", "hello"),
],
);
svc.handle(req).await.unwrap();
}
#[tokio::test]
async fn assume_role_blocked_when_mfa_required_but_not_present() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole","Condition":{"Bool":{"aws:MultiFactorAuthPresent":"true"}}}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "mfa-required", trust);
let req = sts_request(
"AssumeRole",
vec![("RoleArn", &role_arn), ("RoleSessionName", "sess")],
);
let err = match svc.handle(req).await {
Err(e) => e,
Ok(_) => panic!("expected AccessDenied when MFA required but not supplied"),
};
assert_eq!(err.status(), StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn assume_role_succeeds_with_mfa_supplied() {
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole","Condition":{"Bool":{"aws:MultiFactorAuthPresent":"true"}}}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "mfa-ok", trust);
let req = sts_request(
"AssumeRole",
vec![
("RoleArn", &role_arn),
("RoleSessionName", "sess"),
("SerialNumber", "arn:aws:iam::123456789012:mfa/alice"),
("TokenCode", "123456"),
],
);
let resp = svc.handle(req).await.unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
assert!(body.contains("<AccessKeyId>"), "{body}");
let states = state.read();
let s = states.get("123456789012").unwrap();
let any_mfa = s.sts_temp_credentials.values().any(|c| c.mfa_present);
assert!(
any_mfa,
"expected at least one minted credential with mfa_present=true"
);
}
#[tokio::test]
async fn assume_role_with_mfa_resolved_credential_drives_iam_evaluator() {
use crate::credential_resolver::IamCredentialResolver;
use crate::evaluator::{
evaluate as eval_policies, EvalRequest, PolicyDocument, RequestContext,
};
use fakecloud_core::auth::{ConditionContext, CredentialResolver};
let (svc, state) = make_sts_service();
let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole"}]}"#;
let role_arn = create_role_in_state_with_trust(&state, "mfa-e2e", trust);
let req = sts_request(
"AssumeRole",
vec![
("RoleArn", &role_arn),
("RoleSessionName", "ops"),
("SerialNumber", "arn:aws:iam::123456789012:mfa/alice"),
("TokenCode", "654321"),
],
);
let resp = svc.handle(req).await.unwrap();
let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
let access_key_id = body
.split("<AccessKeyId>")
.nth(1)
.and_then(|s| s.split("</AccessKeyId>").next())
.expect("response should contain AccessKeyId")
.to_string();
let resolver = IamCredentialResolver::new(state.clone());
let resolved = resolver
.resolve(&access_key_id)
.expect("issued credential must resolve through the resolver");
assert!(
resolved.mfa_present,
"F3: MFA flag must survive the resolver hop"
);
assert!(
resolved.token_issued_at.is_some(),
"F3: token_issued_at must be populated for STS sessions"
);
let mut ctx: RequestContext = ConditionContext {
aws_principal_arn: Some(resolved.principal.arn.clone()),
aws_principal_account: Some(resolved.principal.account_id.clone()),
aws_userid: Some(resolved.principal.user_id.clone()),
aws_mfa_present: Some(resolved.mfa_present),
aws_token_issue_time: resolved.token_issued_at,
aws_federated_provider: resolved.federated_provider.clone(),
..Default::default()
};
if resolved.mfa_present {
if let Some(issued) = resolved.token_issued_at {
ctx.aws_mfa_age_seconds = Some(
Utc::now()
.signed_duration_since(issued)
.num_seconds()
.max(0),
);
}
}
let policy = PolicyDocument::parse(
r#"{"Version":"2012-10-17","Statement":[{
"Effect":"Allow",
"Action":"s3:GetObject",
"Resource":"*",
"Condition":{"Bool":{"aws:MultiFactorAuthPresent":"true"}}
}]}"#,
);
let eval = EvalRequest {
principal: &resolved.principal,
action: "s3:GetObject".to_string(),
resource: "arn:aws:s3:::secrets/k".to_string(),
context: ctx,
};
let decision = eval_policies(&[policy], &eval);
assert_eq!(
decision,
crate::evaluator::Decision::Allow,
"F3: MFA-gated allow must fire when session was minted with MFA"
);
let req_no_mfa = sts_request(
"AssumeRole",
vec![("RoleArn", &role_arn), ("RoleSessionName", "no-mfa")],
);
let resp_no_mfa = svc.handle(req_no_mfa).await.unwrap();
let body_no_mfa = std::str::from_utf8(resp_no_mfa.body.expect_bytes()).unwrap();
let akid_no_mfa = body_no_mfa
.split("<AccessKeyId>")
.nth(1)
.and_then(|s| s.split("</AccessKeyId>").next())
.unwrap()
.to_string();
let resolved_no_mfa = resolver.resolve(&akid_no_mfa).unwrap();
assert!(!resolved_no_mfa.mfa_present);
let policy2 = PolicyDocument::parse(
r#"{"Version":"2012-10-17","Statement":[{
"Effect":"Allow",
"Action":"s3:GetObject",
"Resource":"*",
"Condition":{"Bool":{"aws:MultiFactorAuthPresent":"true"}}
}]}"#,
);
let ctx2 = ConditionContext {
aws_principal_arn: Some(resolved_no_mfa.principal.arn.clone()),
aws_userid: Some(resolved_no_mfa.principal.user_id.clone()),
aws_mfa_present: Some(resolved_no_mfa.mfa_present),
aws_token_issue_time: resolved_no_mfa.token_issued_at,
..Default::default()
};
let eval2 = EvalRequest {
principal: &resolved_no_mfa.principal,
action: "s3:GetObject".to_string(),
resource: "arn:aws:s3:::secrets/k".to_string(),
context: ctx2,
};
assert_eq!(
eval_policies(&[policy2], &eval2),
crate::evaluator::Decision::ImplicitDeny,
"F3: MFA-gated allow must NOT fire when session was minted without MFA"
);
}
/// AssumeRoleWithSAML must record the SAML provider ARN passed as `PrincipalArn`
/// on the minted session, so that resolving the returned access key exposes it
/// as `aws:FederatedProvider`.
#[tokio::test]
async fn assume_role_with_saml_populates_federated_provider() {
    use crate::credential_resolver::IamCredentialResolver;
    use base64::Engine;
    use fakecloud_core::auth::CredentialResolver;
    let (svc, state) = make_sts_service();
    // Trust policy admits any AWS principal for this action so the call succeeds.
    let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRoleWithSAML"}]}"#;
    let role_arn = create_role_in_state_with_trust(&state, "saml-role", trust);
    // Minimal SAML response (no Signature element) carrying only the
    // RoleSessionName attribute.
    let saml_xml = r#"<?xml version="1.0"?><samlp:Response><Assertion><AttributeStatement><Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><AttributeValue>jane</AttributeValue></Attribute></AttributeStatement></Assertion></samlp:Response>"#;
    let saml_b64 = base64::engine::general_purpose::STANDARD.encode(saml_xml);
    let provider_arn = "arn:aws:iam::123456789012:saml-provider/idp";
    let req = sts_request(
        "AssumeRoleWithSAML",
        vec![
            ("RoleArn", &role_arn),
            ("PrincipalArn", provider_arn),
            ("SAMLAssertion", &saml_b64),
        ],
    );
    let resp = svc
        .handle(req)
        .await
        .expect("AssumeRoleWithSAML should succeed");
    let body = std::str::from_utf8(resp.body.expect_bytes())
        .expect("STS response body must be UTF-8");
    let access_key_id = body
        .split("<AccessKeyId>")
        .nth(1)
        .and_then(|s| s.split("</AccessKeyId>").next())
        .expect("AssumeRoleWithSAML response must contain <AccessKeyId>")
        .to_string();
    // `state` is not used past this point, so move it instead of cloning.
    let resolver = IamCredentialResolver::new(state);
    let resolved = resolver
        .resolve(&access_key_id)
        .expect("freshly minted access key must resolve");
    assert_eq!(
        resolved.federated_provider.as_deref(),
        Some(provider_arn),
        "AssumeRoleWithSAML must populate aws:FederatedProvider with the SAML provider ARN"
    );
}
/// AssumeRoleWithWebIdentity must carry the caller-supplied `ProviderId` onto the
/// minted session, so that resolving the returned access key exposes it as
/// `aws:FederatedProvider`.
#[tokio::test]
async fn assume_role_with_web_identity_populates_federated_provider() {
    use crate::credential_resolver::IamCredentialResolver;
    use fakecloud_core::auth::CredentialResolver;
    let (svc, state) = make_sts_service();
    // Trust policy admits any AWS principal for this action so the call succeeds.
    let trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRoleWithWebIdentity"}]}"#;
    let role_arn = create_role_in_state_with_trust(&state, "oidc-role", trust);
    let req = sts_request(
        "AssumeRoleWithWebIdentity",
        vec![
            ("RoleArn", &role_arn),
            ("RoleSessionName", "oidc-session"),
            // Placeholder token; this test only asserts on ProviderId.
            ("WebIdentityToken", "fake-jwt-blob"),
            ("ProviderId", "accounts.google.com"),
        ],
    );
    let resp = svc
        .handle(req)
        .await
        .expect("AssumeRoleWithWebIdentity should succeed");
    let body = std::str::from_utf8(resp.body.expect_bytes())
        .expect("STS response body must be UTF-8");
    let access_key_id = body
        .split("<AccessKeyId>")
        .nth(1)
        .and_then(|s| s.split("</AccessKeyId>").next())
        .expect("AssumeRoleWithWebIdentity response must contain <AccessKeyId>")
        .to_string();
    // `state` is not used past this point, so move it instead of cloning.
    let resolver = IamCredentialResolver::new(state);
    let resolved = resolver
        .resolve(&access_key_id)
        .expect("freshly minted access key must resolve");
    assert_eq!(
        resolved.federated_provider.as_deref(),
        Some("accounts.google.com"),
        "AssumeRoleWithWebIdentity must carry ProviderId as aws:FederatedProvider"
    );
}
/// The resolved `aws:userid` of an assumed-role session must have the shape
/// `<role-id>:<RoleSessionName>`.
#[tokio::test]
async fn assume_role_userid_format_matches_aws() {
    use crate::credential_resolver::IamCredentialResolver;
    use fakecloud_core::auth::CredentialResolver;
    let (svc, state) = make_sts_service();
    let role_arn = create_role_in_state(&state, "userid-role");
    let req = sts_request(
        "AssumeRole",
        vec![("RoleArn", &role_arn), ("RoleSessionName", "carol")],
    );
    let resp = svc.handle(req).await.expect("AssumeRole should succeed");
    let body = std::str::from_utf8(resp.body.expect_bytes())
        .expect("STS response body must be UTF-8");
    let access_key_id = body
        .split("<AccessKeyId>")
        .nth(1)
        .and_then(|s| s.split("</AccessKeyId>").next())
        .expect("AssumeRole response must contain <AccessKeyId>")
        .to_string();
    let resolver = IamCredentialResolver::new(state);
    let resolved = resolver
        .resolve(&access_key_id)
        .expect("freshly minted access key must resolve");
    let uid = &resolved.principal.user_id;
    // Split on the FIRST ':' so that malformed ids such as `a:b:carol` (which an
    // `ends_with(":carol")` check would accept) are rejected.
    let (role_id, session_name) = uid.split_once(':').unwrap_or_else(|| {
        panic!("assumed-role userid must be `<role-id>:<RoleSessionName>`, got `{uid}`")
    });
    assert!(
        !role_id.is_empty(),
        "assumed-role userid must start with a non-empty role id, got `{uid}`"
    );
    assert_eq!(
        session_name, "carol",
        "assumed-role userid must end with the RoleSessionName, got `{uid}`"
    );
}
/// A service-linked role whose trust policy names only `ecs.amazonaws.com` must
/// refuse `AssumeRole` from the default test caller with HTTP 403.
#[tokio::test]
async fn assume_service_linked_role_blocked_when_caller_not_matching_service() {
    const ACCOUNT: &str = "123456789012";
    const ROLE_NAME: &str = "AWSServiceRoleForECS";
    let (svc, state) = make_sts_service();
    let ecs_only_trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"Service":"ecs.amazonaws.com"},"Action":"sts:AssumeRole"}]}"#;
    let slr_arn = fakecloud_aws::arn::Arn::global(
        "iam",
        ACCOUNT,
        "role/aws-service-role/ecs.amazonaws.com/AWSServiceRoleForECS",
    )
    .to_string();
    let slr = crate::state::IamRole {
        role_name: ROLE_NAME.to_string(),
        role_id: "AROASLRECS".to_string(),
        arn: slr_arn.clone(),
        path: "/aws-service-role/ecs.amazonaws.com/".to_string(),
        assume_role_policy_document: ecs_only_trust.to_string(),
        created_at: Utc::now(),
        description: None,
        max_session_duration: 3600,
        tags: Vec::new(),
        permissions_boundary: None,
    };
    // The write guard is a temporary and is released at the end of this statement.
    state
        .write()
        .get_or_create(ACCOUNT)
        .roles
        .insert(ROLE_NAME.to_string(), slr);
    let assume = sts_request(
        "AssumeRole",
        vec![("RoleArn", &slr_arn), ("RoleSessionName", "sess")],
    );
    let denial = svc
        .handle(assume)
        .await
        .err()
        .expect("expected AccessDenied for service-linked role with non-service caller");
    assert_eq!(denial.status(), StatusCode::FORBIDDEN);
}
/// `AssumeRole` on a service-linked role must fail with HTTP 403 for a
/// non-service caller, even when the role's trust policy allows any AWS principal.
#[tokio::test]
async fn service_linked_role_rejects_non_service_caller() {
    const ACCOUNT_ID: &str = "123456789012";
    let (svc, state) = make_sts_service();
    // Deliberately permissive trust. The expected denial therefore has to come
    // from the role living under `/aws-service-role/`, not from the trust policy.
    let permissive_trust = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"*"},"Action":"sts:AssumeRole"}]}"#;
    let elb_role_arn = fakecloud_aws::arn::Arn::global(
        "iam",
        ACCOUNT_ID,
        "role/aws-service-role/elasticloadbalancing.amazonaws.com/AWSServiceRoleForELB",
    )
    .to_string();
    {
        let mut accounts = state.write();
        let account = accounts.get_or_create(ACCOUNT_ID);
        let role = crate::state::IamRole {
            role_name: "AWSServiceRoleForELB".to_string(),
            role_id: "AROASLR".to_string(),
            arn: elb_role_arn.clone(),
            path: "/aws-service-role/elasticloadbalancing.amazonaws.com/".to_string(),
            assume_role_policy_document: permissive_trust.to_string(),
            created_at: Utc::now(),
            description: None,
            max_session_duration: 3600,
            tags: Vec::new(),
            permissions_boundary: None,
        };
        account.roles.insert("AWSServiceRoleForELB".to_string(), role);
    }
    let outcome = svc
        .handle(sts_request(
            "AssumeRole",
            vec![("RoleArn", &elb_role_arn), ("RoleSessionName", "sess")],
        ))
        .await;
    let err = outcome
        .err()
        .expect("expected AccessDenied for non-service caller");
    assert_eq!(err.status(), StatusCode::FORBIDDEN);
}
/// Dispatching an action name the STS service does not handle must return an error.
#[tokio::test]
async fn unsupported_sts_action() {
    let (svc, _) = make_sts_service();
    let bogus = sts_request("BogusAction", Vec::new());
    let outcome = svc.handle(bogus).await;
    assert!(outcome.is_err());
}
/// `AssumeRole` without a `RoleArn` parameter must be rejected.
#[tokio::test]
async fn assume_role_missing_role_arn_errors() {
    let (svc, _) = make_sts_service();
    // Only the session name is supplied; RoleArn is deliberately absent.
    let request_without_arn = sts_request("AssumeRole", vec![("RoleSessionName", "sess")]);
    let outcome = svc.assume_role(&request_without_arn);
    assert!(outcome.is_err());
}
/// `AssumeRoleWithWebIdentity` without a `WebIdentityToken` must fail.
#[tokio::test]
async fn assume_role_with_web_identity_missing_token_errors() {
    let (svc, _) = make_sts_service();
    // NOTE(review): this test never creates a role at this ARN (and `123` is not a
    // 12-digit account id), so `is_err()` may be satisfied by role/ARN handling
    // rather than by the missing token. Consider creating the role first.
    let params = vec![
        ("RoleArn", "arn:aws:iam::123:role/r"),
        ("RoleSessionName", "s"),
    ];
    let without_token = sts_request("AssumeRoleWithWebIdentity", params);
    let outcome = svc.assume_role_with_web_identity(&without_token);
    assert!(outcome.is_err());
}
/// `AssumeRoleWithSAML` without a `SAMLAssertion` parameter must fail.
#[tokio::test]
async fn assume_role_with_saml_missing_assertion_errors() {
    let (svc, _) = make_sts_service();
    // NOTE(review): this test never creates a role at this ARN (and `123` is not a
    // 12-digit account id), so `is_err()` may be satisfied by role/ARN handling
    // rather than by the missing assertion. Consider creating the role first.
    let params = vec![
        ("RoleArn", "arn:aws:iam::123:role/r"),
        ("PrincipalArn", "arn:aws:iam::123:saml-provider/p"),
    ];
    let without_assertion = sts_request("AssumeRoleWithSAML", params);
    let outcome = svc.assume_role_with_saml(&without_assertion);
    assert!(outcome.is_err());
}
/// `GetSessionToken` with no optional parameters succeeds with HTTP 200.
#[tokio::test]
async fn get_session_token_returns_ok() {
    let (svc, _) = make_sts_service();
    let request = sts_request("GetSessionToken", Vec::new());
    let status = svc.get_session_token(&request).unwrap().status;
    assert_eq!(status, http::StatusCode::OK);
}
/// `GetFederationToken` with a `Name` succeeds with HTTP 200.
#[tokio::test]
async fn get_federation_token_returns_ok() {
    let (svc, _) = make_sts_service();
    let named = sts_request("GetFederationToken", vec![("Name", "test-user")]);
    let status = svc.get_federation_token(&named).unwrap().status;
    assert_eq!(status, http::StatusCode::OK);
}
/// `GetFederationToken` without a `Name` parameter must be rejected.
#[tokio::test]
async fn get_federation_token_missing_name_errors() {
    let (svc, _) = make_sts_service();
    let unnamed = sts_request("GetFederationToken", Vec::new());
    let outcome = svc.get_federation_token(&unnamed);
    assert!(outcome.is_err());
}
/// `AssumeRoot` accepts a bare 12-digit account id as `TargetPrincipal` and
/// returns a response containing credentials.
#[tokio::test]
async fn assume_root_with_account_id_succeeds() {
    let (svc, _) = make_sts_service();
    let audit_policy = "arn:aws:iam::aws:policy/IAMAuditRootUserCredentials";
    let req = sts_request(
        "AssumeRoot",
        vec![
            ("TargetPrincipal", "111122223333"),
            ("TaskPolicyArn.arn", audit_policy),
        ],
    );
    let resp = svc.assume_root(&req).unwrap();
    assert_eq!(resp.status, http::StatusCode::OK);
    let body = std::str::from_utf8(resp.body.expect_bytes())
        .unwrap()
        .to_owned();
    assert!(body.contains("AccessKeyId"), "{body}");
}
/// `AssumeRoot` accepts a root-user ARN as `TargetPrincipal`, together with a
/// `DurationSeconds` of 600 and a `SourceIdentity`, and echoes the source
/// identity back in the response.
#[tokio::test]
async fn assume_root_with_arn_succeeds() {
    let (svc, _) = make_sts_service();
    let params = vec![
        ("TargetPrincipal", "arn:aws:iam::444455556666:root"),
        (
            "TaskPolicyArn.arn",
            "arn:aws:iam::aws:policy/IAMAuditRootUserCredentials",
        ),
        ("DurationSeconds", "600"),
        ("SourceIdentity", "alice"),
    ];
    let req = sts_request("AssumeRoot", params);
    let response = svc.assume_root(&req).unwrap();
    let xml = String::from_utf8(response.body.expect_bytes().to_vec()).unwrap();
    let wanted = "<SourceIdentity>alice</SourceIdentity>";
    assert!(xml.contains(wanted), "{xml}");
}
/// `AssumeRoot` without any `TaskPolicyArn` must fail, and the error text must
/// name the missing `TaskPolicyArn` parameter.
#[tokio::test]
async fn assume_root_missing_task_policy_errors() {
    let (svc, _) = make_sts_service();
    let no_policy = sts_request("AssumeRoot", vec![("TargetPrincipal", "111122223333")]);
    let message = svc
        .assume_root(&no_policy)
        .err()
        .expect("expected err")
        .to_string();
    assert!(message.contains("TaskPolicyArn"));
}
/// `AssumeRoot` must reject a `TargetPrincipal` that is neither a 12-digit
/// account id nor a root-user ARN.
#[tokio::test]
async fn assume_root_invalid_principal_errors() {
    let (svc, _) = make_sts_service();
    // Use the task policy the happy-path tests use, so the malformed
    // TargetPrincipal is the only defect in this request. AWS restricts
    // TaskPolicyArn to a small set of AWS-managed root-task policies, and a
    // placeholder such as `policy/X` could satisfy `is_err()` on its own.
    let req = sts_request(
        "AssumeRoot",
        vec![
            ("TargetPrincipal", "not-an-id"),
            (
                "TaskPolicyArn.arn",
                "arn:aws:iam::aws:policy/IAMAuditRootUserCredentials",
            ),
        ],
    );
    assert!(svc.assume_root(&req).is_err());
}
/// `AssumeRoot` must reject a `DurationSeconds` of 1800. AWS documents a
/// 900-second maximum for AssumeRoot sessions.
#[tokio::test]
async fn assume_root_duration_above_max_errors() {
    let (svc, _) = make_sts_service();
    // Principal and task policy are the values the happy-path tests accept, so an
    // out-of-range duration is the only defect in this request. A placeholder
    // policy such as `policy/X` could satisfy `is_err()` on its own.
    let req = sts_request(
        "AssumeRoot",
        vec![
            ("TargetPrincipal", "111122223333"),
            (
                "TaskPolicyArn.arn",
                "arn:aws:iam::aws:policy/IAMAuditRootUserCredentials",
            ),
            ("DurationSeconds", "1800"),
        ],
    );
    assert!(svc.assume_root(&req).is_err());
}
}