use async_trait::async_trait;
use http::StatusCode;
use fakecloud_core::service::{AwsRequest, AwsResponse, AwsService, AwsServiceError};
use crate::state::{CredentialIdentity, SharedIamState};
use crate::xml_responses::{self, StsCredentials};
/// Emulated AWS Security Token Service.
///
/// Operates on a [`SharedIamState`]: it reads IAM users, roles and access keys,
/// and records the identity behind every set of temporary credentials it issues.
pub struct StsService {
    state: SharedIamState,
}

impl StsService {
    /// Creates an STS service that operates on the given shared IAM state.
    pub fn new(state: SharedIamState) -> Self {
        StsService { state }
    }
}
#[async_trait]
impl AwsService for StsService {
    /// Returns the fixed service name `"sts"`.
    fn service_name(&self) -> &str {
        "sts"
    }

    /// Routes the request to the handler for its `Action`.
    ///
    /// Actions not listed here produce an "action not implemented" error.
    async fn handle(&self, req: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
        let request = &req;
        match request.action.as_str() {
            "GetCallerIdentity" => self.get_caller_identity(request),
            "AssumeRole" => self.assume_role(request),
            "AssumeRoleWithWebIdentity" => self.assume_role_with_web_identity(request),
            "AssumeRoleWithSAML" => self.assume_role_with_saml(request),
            "GetSessionToken" => self.get_session_token(request),
            "GetFederationToken" => self.get_federation_token(request),
            "GetAccessKeyInfo" => self.get_access_key_info(request),
            _ => Err(AwsServiceError::action_not_implemented(
                "sts",
                &request.action,
            )),
        }
    }

    /// Lists every action that `handle` dispatches explicitly.
    fn supported_actions(&self) -> &[&str] {
        const ACTIONS: &[&str] = &[
            "GetCallerIdentity",
            "AssumeRole",
            "AssumeRoleWithWebIdentity",
            "AssumeRoleWithSAML",
            "GetSessionToken",
            "GetFederationToken",
            "GetAccessKeyInfo",
        ];
        ACTIONS
    }
}
/// Maps an AWS region name to the partition used when formatting ARNs.
///
/// Recognised prefixes:
/// - `cn-` → `aws-cn`
/// - `us-gov-` → `aws-us-gov`
/// - `us-iso-`, `us-isob-`, `us-isof-`, `eu-isoe-` → the matching isolated partitions
///
/// Any other region, including an empty string, maps to the commercial `aws`
/// partition.
fn partition_for_region(region: &str) -> &str {
    // Every prefix ends in '-', so `us-iso-` can never shadow `us-isob-` or
    // `us-isof-`, and table order does not affect the result.
    const PARTITIONS_BY_PREFIX: &[(&str, &str)] = &[
        ("cn-", "aws-cn"),
        ("us-gov-", "aws-us-gov"),
        ("us-iso-", "aws-iso"),
        ("us-isob-", "aws-iso-b"),
        ("us-isof-", "aws-iso-f"),
        ("eu-isoe-", "aws-iso-e"),
    ];
    PARTITIONS_BY_PREFIX
        .iter()
        .find(|&&(prefix, _)| region.starts_with(prefix))
        .map(|&(_, partition)| partition)
        .unwrap_or("aws")
}
/// Longest `RoleSessionName` accepted by `AssumeRole`.
///
/// The check compares against the parameter's byte length (`str::len`).
const MAX_ROLE_SESSION_NAME_LENGTH: usize = 64;

/// Longest `Policy` document accepted by `GetFederationToken`.
///
/// The check compares against the parameter's byte length (`str::len`).
const MAX_FEDERATION_TOKEN_POLICY_LENGTH: usize = 2_048;
/// Returns the access key ID from the request's SigV4 `Authorization` header.
///
/// Yields `None` in three cases:
/// - the header is missing;
/// - the header is not visible ASCII;
/// - `parse_sigv4` rejects it.
fn extract_access_key(req: &AwsRequest) -> Option<String> {
    req.headers
        .get("authorization")
        .and_then(|header| header.to_str().ok())
        .and_then(fakecloud_aws::sigv4::parse_sigv4)
        .map(|parsed| parsed.access_key)
}
impl StsService {
fn get_caller_identity(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
let state = self.state.read();
let partition = partition_for_region(&req.region);
if let Some(access_key) = extract_access_key(req) {
if let Some(identity) = state.credential_identities.get(&access_key) {
let xml = xml_responses::get_caller_identity_response(
&identity.account_id,
&identity.arn,
&identity.user_id,
&req.request_id,
);
return Ok(AwsResponse::xml(StatusCode::OK, xml));
}
for keys in state.access_keys.values() {
for key in keys {
if key.access_key_id == access_key {
if let Some(user) = state.users.get(&key.user_name) {
let xml = xml_responses::get_caller_identity_response(
&state.account_id,
&user.arn,
&user.user_id,
&req.request_id,
);
return Ok(AwsResponse::xml(StatusCode::OK, xml));
}
}
}
}
}
let arn = format!("arn:{}:iam::{}:root", partition, state.account_id);
let user_id = "FKIAIOSFODNN7EXAMPLE";
let xml = xml_responses::get_caller_identity_response(
&state.account_id,
&arn,
user_id,
&req.request_id,
);
Ok(AwsResponse::xml(StatusCode::OK, xml))
}
fn assume_role(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"MissingParameter",
"The request must contain the parameter RoleArn",
)
})?;
let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"MissingParameter",
"The request must contain the parameter RoleSessionName",
)
})?;
if role_session_name.len() > MAX_ROLE_SESSION_NAME_LENGTH {
return Err(AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"ValidationError",
format!(
"1 validation error detected: Value '{}' at 'roleSessionName' \
failed to satisfy constraint: Member must have length less than \
or equal to {}",
role_session_name, MAX_ROLE_SESSION_NAME_LENGTH
),
));
}
let partition = partition_for_region(&req.region);
let creds = StsCredentials::generate();
let mut state = self.state.write();
let account_id =
extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
let role_id = state
.roles
.get(role_name)
.map(|r| r.role_id.clone())
.unwrap_or_else(xml_responses::generate_role_id);
let assumed_role_arn = format!(
"arn:{}:sts::{}:assumed-role/{}/{}",
partition, account_id, role_name, role_session_name
);
let assumed_role_id = format!("{}:{}", role_id, role_session_name);
state.credential_identities.insert(
creds.access_key_id.clone(),
CredentialIdentity {
arn: assumed_role_arn,
user_id: assumed_role_id,
account_id: account_id.clone(),
},
);
let xml = xml_responses::assume_role_response(
role_arn,
role_session_name,
&role_id,
&account_id,
partition,
&creds,
&req.request_id,
);
Ok(AwsResponse::xml(StatusCode::OK, xml))
}
fn assume_role_with_web_identity(
&self,
req: &AwsRequest,
) -> Result<AwsResponse, AwsServiceError> {
let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"MissingParameter",
"The request must contain the parameter RoleArn",
)
})?;
let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"MissingParameter",
"The request must contain the parameter RoleSessionName",
)
})?;
let partition = partition_for_region(&req.region);
let creds = StsCredentials::generate();
let role_id = xml_responses::generate_role_id();
let mut state = self.state.write();
let account_id =
extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
let assumed_role_arn = format!(
"arn:{}:sts::{}:assumed-role/{}/{}",
partition, account_id, role_name, role_session_name
);
let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
state.credential_identities.insert(
creds.access_key_id.clone(),
CredentialIdentity {
arn: assumed_role_arn,
user_id: assumed_role_id_str,
account_id: account_id.clone(),
},
);
let xml = xml_responses::assume_role_with_web_identity_response(
role_arn,
role_session_name,
&account_id,
partition,
&creds,
&role_id,
&req.request_id,
);
Ok(AwsResponse::xml(StatusCode::OK, xml))
}
fn assume_role_with_saml(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"MissingParameter",
"The request must contain the parameter RoleArn",
)
})?;
let saml_assertion = req.query_params.get("SAMLAssertion").ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"MissingParameter",
"The request must contain the parameter SAMLAssertion",
)
})?;
let role_session_name =
extract_saml_session_name(saml_assertion).unwrap_or_else(|| "saml-session".to_string());
let partition = partition_for_region(&req.region);
let creds = StsCredentials::generate();
let role_id = xml_responses::generate_role_id();
let mut state = self.state.write();
let account_id =
extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
let assumed_role_arn = format!(
"arn:{}:sts::{}:assumed-role/{}/{}",
partition, account_id, role_name, &role_session_name
);
let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
state.credential_identities.insert(
creds.access_key_id.clone(),
CredentialIdentity {
arn: assumed_role_arn,
user_id: assumed_role_id_str,
account_id: account_id.clone(),
},
);
let xml = xml_responses::assume_role_with_saml_response(
role_arn,
&role_session_name,
&account_id,
partition,
&creds,
&role_id,
&req.request_id,
);
Ok(AwsResponse::xml(StatusCode::OK, xml))
}
fn get_session_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
let xml = xml_responses::get_session_token_response(&req.request_id);
Ok(AwsResponse::xml(StatusCode::OK, xml))
}
fn get_federation_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
let name = req.query_params.get("Name").ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"MissingParameter",
"The request must contain the parameter Name",
)
})?;
if let Some(policy) = req.query_params.get("Policy") {
if policy.len() > MAX_FEDERATION_TOKEN_POLICY_LENGTH {
return Err(AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"ValidationError",
format!(
"1 validation error detected: Value at 'policy' failed to \
satisfy constraint: Member must have length less than or \
equal to {}",
MAX_FEDERATION_TOKEN_POLICY_LENGTH
),
));
}
}
let partition = partition_for_region(&req.region);
let state = self.state.read();
let xml = xml_responses::get_federation_token_response(
name,
&state.account_id,
partition,
&req.request_id,
);
Ok(AwsResponse::xml(StatusCode::OK, xml))
}
fn get_access_key_info(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
let access_key_id = req.query_params.get("AccessKeyId").ok_or_else(|| {
AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"MissingParameter",
"The request must contain the parameter AccessKeyId",
)
})?;
let state = self.state.read();
let account_id = state
.access_keys
.values()
.flatten()
.find(|k| k.access_key_id == *access_key_id)
.map(|_| state.account_id.clone())
.or_else(|| {
state
.credential_identities
.get(access_key_id.as_str())
.map(|ci| ci.account_id.clone())
})
.unwrap_or_else(|| state.account_id.clone());
let xml = xml_responses::get_access_key_info_response(&account_id, &req.request_id);
Ok(AwsResponse::xml(StatusCode::OK, xml))
}
}
/// Returns the account-ID field of an ARN, i.e. its fifth `:`-separated
/// segment.
///
/// Yields `None` when the ARN has fewer than five segments or when that
/// segment is empty (as in `arn:aws:s3:::bucket`).
fn extract_account_from_arn(arn: &str) -> Option<String> {
    arn.split(':')
        .nth(4)
        .filter(|account| !account.is_empty())
        .map(str::to_owned)
}
/// Extracts the `RoleSessionName` attribute value from a base64-encoded SAML
/// response.
///
/// This is a lightweight text scan, not an XML parse. It finds the
/// `https://aws.amazon.com/SAML/Attributes/RoleSessionName` attribute name,
/// then returns the trimmed text of the first `AttributeValue` element after
/// it. Namespace prefixes such as `saml:` are tolerated.
///
/// Returns `None` when any of these hold:
/// - the payload is not valid base64;
/// - the decoded bytes are not UTF-8;
/// - the attribute or its value element is missing;
/// - the value is blank.
fn extract_saml_session_name(saml_b64: &str) -> Option<String> {
    use base64::Engine;
    const ROLE_SESSION_ATTR: &str = "https://aws.amazon.com/SAML/Attributes/RoleSessionName";

    // Base64 payloads may arrive line-wrapped (MIME-style CR/LF). The STANDARD
    // engine rejects embedded whitespace, so drop it before decoding.
    // Otherwise a valid assertion would silently lose its session name.
    let compact: String = saml_b64
        .chars()
        .filter(|c| !c.is_ascii_whitespace())
        .collect();
    let decoded = base64::engine::general_purpose::STANDARD
        .decode(compact.as_bytes())
        .ok()?;
    let xml_str = String::from_utf8(decoded).ok()?;

    let after_attr = &xml_str[xml_str.find(ROLE_SESSION_ATTR)?..];
    let after_av = &after_attr[after_attr.find("AttributeValue")?..];
    // Skip to the end of the (possibly attribute-bearing) opening tag.
    let value_start = &after_av[after_av.find('>')? + 1..];
    let value = value_start[..value_start.find('<')?].trim();
    (!value.is_empty()).then(|| value.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;

    /// Base64-encodes a SAML document with the standard alphabet.
    fn encode_saml(xml: &str) -> String {
        base64::engine::general_purpose::STANDARD.encode(xml.as_bytes())
    }

    #[test]
    fn test_partition_for_region() {
        let cases = [
            ("us-east-1", "aws"),
            ("eu-west-1", "aws"),
            ("cn-north-1", "aws-cn"),
            ("cn-northwest-1", "aws-cn"),
            ("us-isob-east-1", "aws-iso-b"),
            ("us-iso-east-1", "aws-iso"),
        ];
        for &(region, expected) in cases.iter() {
            assert_eq!(partition_for_region(region), expected, "region {}", region);
        }
    }

    #[test]
    fn test_extract_account_from_arn() {
        for &account in ["123456789012", "111111111111"].iter() {
            let arn = format!("arn:aws:iam::{}:role/test", account);
            assert_eq!(extract_account_from_arn(&arn), Some(account.to_string()));
        }
        assert_eq!(extract_account_from_arn("invalid"), None);
    }

    #[test]
    fn test_extract_saml_session_name() {
        let xml = r#"<?xml version="1.0"?><samlp:Response><Assertion><AttributeStatement><Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><AttributeValue>testuser</AttributeValue></Attribute></AttributeStatement></Assertion></samlp:Response>"#;
        let session = extract_saml_session_name(&encode_saml(xml));
        assert_eq!(session.as_deref(), Some("testuser"));
    }

    #[test]
    fn test_extract_saml_session_name_with_namespace() {
        let xml = r#"<?xml version="1.0"?><samlp:Response><saml:Assertion><saml:AttributeStatement><saml:Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><saml:AttributeValue>testuser</saml:AttributeValue></saml:Attribute></saml:AttributeStatement></saml:Assertion></samlp:Response>"#;
        let session = extract_saml_session_name(&encode_saml(xml));
        assert_eq!(session.as_deref(), Some("testuser"));
    }

    #[test]
    fn test_session_token_format() {
        let token = xml_responses::generate_session_token();
        assert!(token.starts_with("FQoGZXIvYXdzE"));
        assert_eq!(token.len(), 356);
    }

    #[test]
    fn test_access_key_id_format() {
        let key = xml_responses::generate_access_key_id();
        assert!(key.starts_with("FSIA"));
        assert_eq!(key.len(), 20);
    }

    #[test]
    fn test_secret_access_key_format() {
        assert_eq!(xml_responses::generate_secret_access_key().len(), 40);
    }

    #[test]
    fn test_role_id_format() {
        let id = xml_responses::generate_role_id();
        assert!(id.starts_with("AROA"));
        assert_eq!(id.len(), 21);
    }
}