Skip to main content

fakecloud_iam/
sts_service.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use http::StatusCode;
4
5use fakecloud_core::service::{AwsRequest, AwsResponse, AwsService, AwsServiceError};
6use fakecloud_core::validation::*;
7
8use crate::state::{CredentialIdentity, SharedIamState, StsTempCredential};
9use crate::xml_responses::{self, StsCredentials};
10
/// Default lifetime, in seconds, of credentials issued by `AssumeRole` and
/// its `AssumeRoleWith*` variants when `DurationSeconds` is omitted (1 hour).
const DEFAULT_ASSUME_ROLE_DURATION: i64 = 60 * 60;

/// Default lifetime, in seconds, of `GetSessionToken` credentials when
/// `DurationSeconds` is omitted (12 hours).
const DEFAULT_SESSION_TOKEN_DURATION: i64 = 12 * 60 * 60;

/// Default lifetime, in seconds, of `GetFederationToken` credentials when
/// `DurationSeconds` is omitted (12 hours).
const DEFAULT_FEDERATION_TOKEN_DURATION: i64 = 12 * 60 * 60;
19
20/// Compute an absolute expiration timestamp from an optional DurationSeconds parameter.
21fn compute_expiration_at(
22    req: &AwsRequest,
23    default_duration: i64,
24) -> Result<DateTime<Utc>, AwsServiceError> {
25    let duration = if let Some(ds) = req.query_params.get("DurationSeconds") {
26        ds.parse::<i64>().map_err(|_| {
27            AwsServiceError::aws_error(
28                StatusCode::BAD_REQUEST,
29                "ValidationError",
30                format!(
31                    "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
32                     Member must be a valid integer",
33                    ds
34                ),
35            )
36        })?
37    } else {
38        default_duration
39    };
40    Ok(Utc::now() + chrono::Duration::seconds(duration))
41}
42
43/// Format an expiration timestamp as the ISO 8601 string AWS returns.
44fn format_expiration(ts: DateTime<Utc>) -> String {
45    ts.format("%Y-%m-%dT%H:%M:%SZ").to_string()
46}
47
/// Test-only convenience used by the duration unit tests: runs
/// [`compute_expiration_at`] and formats the result with
/// [`format_expiration`].
#[cfg(test)]
fn compute_expiration(req: &AwsRequest, default_duration: i64) -> Result<String, AwsServiceError> {
    compute_expiration_at(req, default_duration).map(format_expiration)
}
57
58pub struct StsService {
59    state: SharedIamState,
60}
61
62impl StsService {
63    pub fn new(state: SharedIamState) -> Self {
64        Self { state }
65    }
66}
67
68#[async_trait]
69impl AwsService for StsService {
70    fn service_name(&self) -> &str {
71        "sts"
72    }
73
74    async fn handle(&self, req: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
75        match req.action.as_str() {
76            "GetCallerIdentity" => self.get_caller_identity(&req),
77            "AssumeRole" => self.assume_role(&req),
78            "AssumeRoleWithWebIdentity" => self.assume_role_with_web_identity(&req),
79            "AssumeRoleWithSAML" => self.assume_role_with_saml(&req),
80            "GetSessionToken" => self.get_session_token(&req),
81            "GetFederationToken" => self.get_federation_token(&req),
82            "GetAccessKeyInfo" => self.get_access_key_info(&req),
83            "DecodeAuthorizationMessage" => self.decode_authorization_message(&req),
84            _ => Err(AwsServiceError::action_not_implemented("sts", &req.action)),
85        }
86    }
87
88    fn supported_actions(&self) -> &[&str] {
89        &[
90            "GetCallerIdentity",
91            "AssumeRole",
92            "AssumeRoleWithWebIdentity",
93            "AssumeRoleWithSAML",
94            "GetSessionToken",
95            "GetFederationToken",
96            "GetAccessKeyInfo",
97            "DecodeAuthorizationMessage",
98        ]
99    }
100
101    /// STS opts into Phase 1 IAM enforcement.
102    fn iam_enforceable(&self) -> bool {
103        true
104    }
105
106    /// STS actions operate on `*` per AWS — see
107    /// <https://docs.aws.amazon.com/service-authorization/latest/reference/list_awssecuritytokenservice.html>.
108    /// `AssumeRole*` variants additionally carry a role ARN as the
109    /// target resource so policies can scope by role name.
110    fn iam_action_for(
111        &self,
112        request: &fakecloud_core::service::AwsRequest,
113    ) -> Option<fakecloud_core::auth::IamAction> {
114        let action: &'static str = match request.action.as_str() {
115            "GetCallerIdentity" => "GetCallerIdentity",
116            "AssumeRole" => "AssumeRole",
117            "AssumeRoleWithWebIdentity" => "AssumeRoleWithWebIdentity",
118            "AssumeRoleWithSAML" => "AssumeRoleWithSAML",
119            "GetSessionToken" => "GetSessionToken",
120            "GetFederationToken" => "GetFederationToken",
121            "GetAccessKeyInfo" => "GetAccessKeyInfo",
122            "DecodeAuthorizationMessage" => "DecodeAuthorizationMessage",
123            _ => return None,
124        };
125        let resource = match action {
126            "AssumeRole" | "AssumeRoleWithWebIdentity" | "AssumeRoleWithSAML" => request
127                .query_params
128                .get("RoleArn")
129                .cloned()
130                .unwrap_or_else(|| "*".to_string()),
131            _ => "*".to_string(),
132        };
133        Some(fakecloud_core::auth::IamAction {
134            service: "sts",
135            action,
136            resource,
137        })
138    }
139}
140
/// Map a region name to its AWS partition (the second field of an ARN).
///
/// Matching is by region-name prefix. Any region not covered by the table,
/// including the empty string, falls back to the commercial `aws` partition.
fn partition_for_region(region: &str) -> &'static str {
    // Prefix -> partition. Each prefix ends in `-`, so `us-iso-` cannot match
    // `us-isob-…` / `us-isof-…`, and table order does not matter.
    const PARTITIONS: [(&str, &str); 6] = [
        ("cn-", "aws-cn"),
        // GovCloud (us-gov-west-1, us-gov-east-1) has its own partition and
        // previously fell through to `aws`, producing wrong ARNs.
        ("us-gov-", "aws-us-gov"),
        ("us-iso-", "aws-iso"),
        ("us-isob-", "aws-iso-b"),
        ("us-isof-", "aws-iso-f"),
        ("eu-isoe-", "aws-iso-e"),
    ];

    PARTITIONS
        .iter()
        .find(|(prefix, _)| region.starts_with(*prefix))
        .map_or("aws", |&(_, partition)| partition)
}
157
158/// Extract the caller's access key from the SigV4 Authorization header.
159fn extract_access_key(req: &AwsRequest) -> Option<String> {
160    let auth = req.headers.get("authorization")?.to_str().ok()?;
161    let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
162    Some(info.access_key)
163}
164
165impl StsService {
166    fn get_caller_identity(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
167        // Prefer the pre-resolved principal that dispatch attached via the
168        // credential resolver — avoids re-parsing the Authorization header
169        // and re-walking IamState. Falls back to the account-root identity
170        // when the caller didn't present resolvable credentials (e.g. the
171        // `test` root bypass, or no auth header at all).
172        if let Some(principal) = req.principal.as_ref() {
173            let xml = xml_responses::get_caller_identity_response(
174                &principal.account_id,
175                &principal.arn,
176                &principal.user_id,
177                &req.request_id,
178            );
179            return Ok(AwsResponse::xml(StatusCode::OK, xml));
180        }
181
182        let state = self.state.read();
183        let partition = partition_for_region(&req.region);
184        let arn = format!("arn:{}:iam::{}:root", partition, state.account_id);
185        let user_id = "FKIAIOSFODNN7EXAMPLE";
186        let xml = xml_responses::get_caller_identity_response(
187            &state.account_id,
188            &arn,
189            user_id,
190            &req.request_id,
191        );
192        Ok(AwsResponse::xml(StatusCode::OK, xml))
193    }
194
195    fn assume_role(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
196        let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
197            AwsServiceError::aws_error(
198                StatusCode::BAD_REQUEST,
199                "MissingParameter",
200                "The request must contain the parameter RoleArn",
201            )
202        })?;
203        validate_string_length("roleArn", role_arn, 20, 2048)?;
204
205        let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
206            AwsServiceError::aws_error(
207                StatusCode::BAD_REQUEST,
208                "MissingParameter",
209                "The request must contain the parameter RoleSessionName",
210            )
211        })?;
212        validate_string_length("roleSessionName", role_session_name, 2, 64)?;
213
214        // Validate optional DurationSeconds (used below for expiration)
215        if let Some(ds) = req.query_params.get("DurationSeconds") {
216            let v = ds.parse::<i64>().map_err(|_| {
217                AwsServiceError::aws_error(
218                    StatusCode::BAD_REQUEST,
219                    "ValidationError",
220                    format!(
221                        "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
222                         Member must be a valid integer",
223                        ds
224                    ),
225                )
226            })?;
227            validate_range_i64("durationSeconds", v, 900, 43200)?;
228        }
229
230        // Validate optional ExternalId
231        validate_optional_string_length(
232            "externalId",
233            req.query_params.get("ExternalId").map(|s| s.as_str()),
234            2,
235            1224,
236        )?;
237
238        // Validate optional Policy
239        validate_optional_string_length(
240            "policy",
241            req.query_params.get("Policy").map(|s| s.as_str()),
242            1,
243            2048,
244        )?;
245
246        // Validate optional SourceIdentity
247        validate_optional_string_length(
248            "sourceIdentity",
249            req.query_params.get("SourceIdentity").map(|s| s.as_str()),
250            2,
251            64,
252        )?;
253
254        // Validate and accept optional MFA SerialNumber
255        validate_optional_string_length(
256            "serialNumber",
257            req.query_params.get("SerialNumber").map(|s| s.as_str()),
258            9,
259            256,
260        )?;
261        let serial_number = req.query_params.get("SerialNumber").cloned();
262
263        // Validate and accept optional MFA TokenCode
264        validate_optional_string_length(
265            "tokenCode",
266            req.query_params.get("TokenCode").map(|s| s.as_str()),
267            6,
268            6,
269        )?;
270        let token_code = req.query_params.get("TokenCode").cloned();
271
272        // Compute expiration from DurationSeconds (default 3600s)
273        let expiration_at = compute_expiration_at(req, DEFAULT_ASSUME_ROLE_DURATION)?;
274        let expiration = format_expiration(expiration_at);
275
276        // Accept MFA parameters without verification (emulator behavior)
277        let _mfa_serial = serial_number;
278        let _mfa_token = token_code;
279
280        let partition = partition_for_region(&req.region);
281        let creds = StsCredentials::generate();
282
283        let mut state = self.state.write();
284
285        // Extract account ID from role ARN if present, otherwise use default
286        let account_id =
287            extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
288
289        // Try to find the role in state to get its role_id
290        let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
291        let role_id = state
292            .roles
293            .get(role_name)
294            .map(|r| r.role_id.clone())
295            .unwrap_or_else(xml_responses::generate_role_id);
296
297        let assumed_role_arn = format!(
298            "arn:{}:sts::{}:assumed-role/{}/{}",
299            partition, account_id, role_name, role_session_name
300        );
301        let assumed_role_id = format!("{}:{}", role_id, role_session_name);
302
303        // Store credential identity for GetCallerIdentity lookups
304        state.credential_identities.insert(
305            creds.access_key_id.clone(),
306            CredentialIdentity {
307                arn: assumed_role_arn.clone(),
308                user_id: assumed_role_id.clone(),
309                account_id: account_id.clone(),
310            },
311        );
312        // Store full temp credential so SigV4 verification (batch 3) and IAM
313        // enforcement (batch 4+) can look up the secret by access key ID.
314        state.sts_temp_credentials.insert(
315            creds.access_key_id.clone(),
316            StsTempCredential {
317                access_key_id: creds.access_key_id.clone(),
318                secret_access_key: creds.secret_access_key.clone(),
319                session_token: creds.session_token.clone(),
320                principal_arn: assumed_role_arn,
321                user_id: assumed_role_id,
322                account_id: account_id.clone(),
323                expiration: expiration_at,
324            },
325        );
326
327        let xml = xml_responses::assume_role_response(&xml_responses::AssumedRoleInfo {
328            role_arn,
329            role_session_name,
330            assumed_role_id: &role_id,
331            account_id: &account_id,
332            partition,
333            creds: &creds,
334            expiration: &expiration,
335            request_id: &req.request_id,
336        });
337        Ok(AwsResponse::xml(StatusCode::OK, xml))
338    }
339
340    fn assume_role_with_web_identity(
341        &self,
342        req: &AwsRequest,
343    ) -> Result<AwsResponse, AwsServiceError> {
344        let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
345            AwsServiceError::aws_error(
346                StatusCode::BAD_REQUEST,
347                "MissingParameter",
348                "The request must contain the parameter RoleArn",
349            )
350        })?;
351        validate_string_length("roleArn", role_arn, 20, 2048)?;
352
353        let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
354            AwsServiceError::aws_error(
355                StatusCode::BAD_REQUEST,
356                "MissingParameter",
357                "The request must contain the parameter RoleSessionName",
358            )
359        })?;
360        validate_string_length("roleSessionName", role_session_name, 2, 64)?;
361
362        // WebIdentityToken is required
363        let web_identity_token = req.query_params.get("WebIdentityToken").ok_or_else(|| {
364            AwsServiceError::aws_error(
365                StatusCode::BAD_REQUEST,
366                "MissingParameter",
367                "The request must contain the parameter WebIdentityToken",
368            )
369        })?;
370        validate_string_length("webIdentityToken", web_identity_token, 4, 20000)?;
371        let _web_identity_token = web_identity_token.clone();
372
373        // Validate optional Policy
374        validate_optional_string_length(
375            "policy",
376            req.query_params.get("Policy").map(|s| s.as_str()),
377            1,
378            2048,
379        )?;
380
381        // Validate optional ProviderId
382        validate_optional_string_length(
383            "providerId",
384            req.query_params.get("ProviderId").map(|s| s.as_str()),
385            4,
386            2048,
387        )?;
388
389        // Validate optional DurationSeconds (used below for expiration)
390        if let Some(ds) = req.query_params.get("DurationSeconds") {
391            let v = ds.parse::<i64>().map_err(|_| {
392                AwsServiceError::aws_error(
393                    StatusCode::BAD_REQUEST,
394                    "ValidationError",
395                    format!(
396                        "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
397                         Member must be a valid integer",
398                        ds
399                    ),
400                )
401            })?;
402            validate_range_i64("durationSeconds", v, 900, 43200)?;
403        }
404
405        // Compute expiration from DurationSeconds (default 3600s)
406        let expiration_at = compute_expiration_at(req, DEFAULT_ASSUME_ROLE_DURATION)?;
407        let expiration = format_expiration(expiration_at);
408
409        let partition = partition_for_region(&req.region);
410        let creds = StsCredentials::generate();
411        let role_id = xml_responses::generate_role_id();
412
413        let mut state = self.state.write();
414        let account_id =
415            extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
416
417        let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
418        let assumed_role_arn = format!(
419            "arn:{}:sts::{}:assumed-role/{}/{}",
420            partition, account_id, role_name, role_session_name
421        );
422        let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
423
424        state.credential_identities.insert(
425            creds.access_key_id.clone(),
426            CredentialIdentity {
427                arn: assumed_role_arn.clone(),
428                user_id: assumed_role_id_str.clone(),
429                account_id: account_id.clone(),
430            },
431        );
432        state.sts_temp_credentials.insert(
433            creds.access_key_id.clone(),
434            StsTempCredential {
435                access_key_id: creds.access_key_id.clone(),
436                secret_access_key: creds.secret_access_key.clone(),
437                session_token: creds.session_token.clone(),
438                principal_arn: assumed_role_arn,
439                user_id: assumed_role_id_str,
440                account_id: account_id.clone(),
441                expiration: expiration_at,
442            },
443        );
444
445        let xml = xml_responses::assume_role_with_web_identity_response(
446            &xml_responses::AssumedRoleInfo {
447                role_arn,
448                role_session_name,
449                assumed_role_id: &role_id,
450                account_id: &account_id,
451                partition,
452                creds: &creds,
453                expiration: &expiration,
454                request_id: &req.request_id,
455            },
456        );
457        Ok(AwsResponse::xml(StatusCode::OK, xml))
458    }
459
460    fn assume_role_with_saml(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
461        let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
462            AwsServiceError::aws_error(
463                StatusCode::BAD_REQUEST,
464                "MissingParameter",
465                "The request must contain the parameter RoleArn",
466            )
467        })?;
468        validate_string_length("roleArn", role_arn, 20, 2048)?;
469
470        // PrincipalArn is required
471        let principal_arn = req.query_params.get("PrincipalArn").ok_or_else(|| {
472            AwsServiceError::aws_error(
473                StatusCode::BAD_REQUEST,
474                "MissingParameter",
475                "The request must contain the parameter PrincipalArn",
476            )
477        })?;
478        validate_string_length("principalArn", principal_arn, 20, 2048)?;
479        let _principal_arn = principal_arn.clone();
480
481        // SAMLAssertion is required but we just need to extract session name from it
482        let saml_assertion = req.query_params.get("SAMLAssertion").ok_or_else(|| {
483            AwsServiceError::aws_error(
484                StatusCode::BAD_REQUEST,
485                "MissingParameter",
486                "The request must contain the parameter SAMLAssertion",
487            )
488        })?;
489        validate_string_length("sAMLAssertion", saml_assertion, 4, 100000)?;
490
491        // Validate optional Policy
492        validate_optional_string_length(
493            "policy",
494            req.query_params.get("Policy").map(|s| s.as_str()),
495            1,
496            2048,
497        )?;
498
499        // Validate optional DurationSeconds (used below for expiration)
500        if let Some(ds) = req.query_params.get("DurationSeconds") {
501            let v = ds.parse::<i64>().map_err(|_| {
502                AwsServiceError::aws_error(
503                    StatusCode::BAD_REQUEST,
504                    "ValidationError",
505                    format!(
506                        "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
507                         Member must be a valid integer",
508                        ds
509                    ),
510                )
511            })?;
512            validate_range_i64("durationSeconds", v, 900, 43200)?;
513        }
514
515        // Compute expiration from DurationSeconds (default 3600s)
516        let expiration_at = compute_expiration_at(req, DEFAULT_ASSUME_ROLE_DURATION)?;
517        let expiration = format_expiration(expiration_at);
518
519        // Decode the SAML assertion to extract the RoleSessionName
520        let role_session_name =
521            extract_saml_session_name(saml_assertion).unwrap_or_else(|| "saml-session".to_string());
522
523        let partition = partition_for_region(&req.region);
524        let creds = StsCredentials::generate();
525        let role_id = xml_responses::generate_role_id();
526
527        let mut state = self.state.write();
528        let account_id =
529            extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
530
531        let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
532        let assumed_role_arn = format!(
533            "arn:{}:sts::{}:assumed-role/{}/{}",
534            partition, account_id, role_name, &role_session_name
535        );
536        let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
537
538        state.credential_identities.insert(
539            creds.access_key_id.clone(),
540            CredentialIdentity {
541                arn: assumed_role_arn.clone(),
542                user_id: assumed_role_id_str.clone(),
543                account_id: account_id.clone(),
544            },
545        );
546        state.sts_temp_credentials.insert(
547            creds.access_key_id.clone(),
548            StsTempCredential {
549                access_key_id: creds.access_key_id.clone(),
550                secret_access_key: creds.secret_access_key.clone(),
551                session_token: creds.session_token.clone(),
552                principal_arn: assumed_role_arn,
553                user_id: assumed_role_id_str,
554                account_id: account_id.clone(),
555                expiration: expiration_at,
556            },
557        );
558
559        let xml = xml_responses::assume_role_with_saml_response(&xml_responses::AssumedRoleInfo {
560            role_arn,
561            role_session_name: &role_session_name,
562            assumed_role_id: &role_id,
563            account_id: &account_id,
564            partition,
565            creds: &creds,
566            expiration: &expiration,
567            request_id: &req.request_id,
568        });
569        Ok(AwsResponse::xml(StatusCode::OK, xml))
570    }
571
572    fn get_session_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
573        // Validate optional DurationSeconds (used below for expiration)
574        if let Some(ds) = req.query_params.get("DurationSeconds") {
575            let v = ds.parse::<i64>().map_err(|_| {
576                AwsServiceError::aws_error(
577                    StatusCode::BAD_REQUEST,
578                    "ValidationError",
579                    format!(
580                        "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
581                         Member must be a valid integer",
582                        ds
583                    ),
584                )
585            })?;
586            validate_range_i64("durationSeconds", v, 900, 129600)?;
587        }
588
589        // Validate and accept optional MFA SerialNumber (no verification in emulator)
590        validate_optional_string_length(
591            "serialNumber",
592            req.query_params.get("SerialNumber").map(|s| s.as_str()),
593            9,
594            256,
595        )?;
596        let _serial_number = req.query_params.get("SerialNumber").cloned();
597
598        // Validate and accept optional MFA TokenCode (no verification in emulator)
599        validate_optional_string_length(
600            "tokenCode",
601            req.query_params.get("TokenCode").map(|s| s.as_str()),
602            6,
603            6,
604        )?;
605        let _token_code = req.query_params.get("TokenCode").cloned();
606
607        // Compute expiration from DurationSeconds (default 43200s / 12 hours)
608        let expiration_at = compute_expiration_at(req, DEFAULT_SESSION_TOKEN_DURATION)?;
609        let expiration = format_expiration(expiration_at);
610
611        // Resolve the calling principal so the temporary credential is tied
612        // to a real identity that SigV4 verification and IAM enforcement can
613        // look up later. Falls back to the account root when the caller
614        // isn't a known IAM user — matches how `GetSessionToken` behaves
615        // against AWS with root credentials.
616        let partition = partition_for_region(&req.region);
617        let mut state = self.state.write();
618        let (principal_arn, user_id, account_id) =
619            if let Some(akid) = extract_access_key(req).as_deref() {
620                if let Some(lookup) = state.credential_secret(akid) {
621                    (lookup.principal_arn, lookup.user_id, lookup.account_id)
622                } else {
623                    (
624                        format!("arn:{}:iam::{}:root", partition, state.account_id),
625                        state.account_id.clone(),
626                        state.account_id.clone(),
627                    )
628                }
629            } else {
630                (
631                    format!("arn:{}:iam::{}:root", partition, state.account_id),
632                    state.account_id.clone(),
633                    state.account_id.clone(),
634                )
635            };
636
637        let creds = StsCredentials::generate();
638        state.credential_identities.insert(
639            creds.access_key_id.clone(),
640            CredentialIdentity {
641                arn: principal_arn.clone(),
642                user_id: user_id.clone(),
643                account_id: account_id.clone(),
644            },
645        );
646        state.sts_temp_credentials.insert(
647            creds.access_key_id.clone(),
648            StsTempCredential {
649                access_key_id: creds.access_key_id.clone(),
650                secret_access_key: creds.secret_access_key.clone(),
651                session_token: creds.session_token.clone(),
652                principal_arn,
653                user_id,
654                account_id,
655                expiration: expiration_at,
656            },
657        );
658
659        let xml = xml_responses::get_session_token_response(&creds, &expiration, &req.request_id);
660        Ok(AwsResponse::xml(StatusCode::OK, xml))
661    }
662
663    fn get_federation_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
664        let name = req.query_params.get("Name").ok_or_else(|| {
665            AwsServiceError::aws_error(
666                StatusCode::BAD_REQUEST,
667                "MissingParameter",
668                "The request must contain the parameter Name",
669            )
670        })?;
671        validate_string_length("name", name, 2, 32)?;
672
673        // Validate optional DurationSeconds (used below for expiration)
674        if let Some(ds) = req.query_params.get("DurationSeconds") {
675            let v = ds.parse::<i64>().map_err(|_| {
676                AwsServiceError::aws_error(
677                    StatusCode::BAD_REQUEST,
678                    "ValidationError",
679                    format!(
680                        "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
681                         Member must be a valid integer",
682                        ds
683                    ),
684                )
685            })?;
686            validate_range_i64("durationSeconds", v, 900, 129600)?;
687        }
688
689        // Validate and store optional policy
690        validate_optional_string_length(
691            "policy",
692            req.query_params.get("Policy").map(|s| s.as_str()),
693            1,
694            2048,
695        )?;
696        let policy = req.query_params.get("Policy").cloned();
697
698        // Compute expiration from DurationSeconds (default 43200s / 12 hours)
699        let expiration_at = compute_expiration_at(req, DEFAULT_FEDERATION_TOKEN_DURATION)?;
700        let expiration = format_expiration(expiration_at);
701
702        let partition = partition_for_region(&req.region);
703        let creds = StsCredentials::generate();
704
705        let mut state = self.state.write();
706        let account_id = state.account_id.clone();
707        let federated_user_arn = format!(
708            "arn:{}:sts::{}:federated-user/{}",
709            partition, account_id, name
710        );
711        let federated_user_id = format!("{}:{}", account_id, name);
712
713        state.credential_identities.insert(
714            creds.access_key_id.clone(),
715            CredentialIdentity {
716                arn: federated_user_arn.clone(),
717                user_id: federated_user_id.clone(),
718                account_id: account_id.clone(),
719            },
720        );
721        state.sts_temp_credentials.insert(
722            creds.access_key_id.clone(),
723            StsTempCredential {
724                access_key_id: creds.access_key_id.clone(),
725                secret_access_key: creds.secret_access_key.clone(),
726                session_token: creds.session_token.clone(),
727                principal_arn: federated_user_arn,
728                user_id: federated_user_id,
729                account_id: account_id.clone(),
730                expiration: expiration_at,
731            },
732        );
733
734        let xml = xml_responses::get_federation_token_response(
735            &creds,
736            name,
737            &account_id,
738            partition,
739            &expiration,
740            policy.as_deref(),
741            &req.request_id,
742        );
743        Ok(AwsResponse::xml(StatusCode::OK, xml))
744    }
745
746    fn decode_authorization_message(
747        &self,
748        req: &AwsRequest,
749    ) -> Result<AwsResponse, AwsServiceError> {
750        let encoded_message = req.query_params.get("EncodedMessage").ok_or_else(|| {
751            AwsServiceError::aws_error(
752                StatusCode::BAD_REQUEST,
753                "MissingParameter",
754                "The request must contain the parameter EncodedMessage",
755            )
756        })?;
757        validate_string_length("encodedMessage", encoded_message, 1, 10240)?;
758
759        let decoded_message =
760            r#"{"allowed":true,"explicitDeny":false,"matchedStatements":{"items":[]}}"#;
761        let xml =
762            xml_responses::decode_authorization_message_response(decoded_message, &req.request_id);
763        Ok(AwsResponse::xml(StatusCode::OK, xml))
764    }
765
766    fn get_access_key_info(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
767        let access_key_id = req.query_params.get("AccessKeyId").ok_or_else(|| {
768            AwsServiceError::aws_error(
769                StatusCode::BAD_REQUEST,
770                "MissingParameter",
771                "The request must contain the parameter AccessKeyId",
772            )
773        })?;
774        validate_string_length("accessKeyId", access_key_id, 16, 128)?;
775
776        // Try to resolve account from known access keys, fall back to default
777        let state = self.state.read();
778        let account_id = state
779            .access_keys
780            .values()
781            .flatten()
782            .find(|k| k.access_key_id == *access_key_id)
783            .map(|_| state.account_id.clone())
784            .or_else(|| {
785                state
786                    .credential_identities
787                    .get(access_key_id.as_str())
788                    .map(|ci| ci.account_id.clone())
789            })
790            .unwrap_or_else(|| state.account_id.clone());
791
792        let xml = xml_responses::get_access_key_info_response(&account_id, &req.request_id);
793        Ok(AwsResponse::xml(StatusCode::OK, xml))
794    }
795}
796
/// Returns the account-ID segment of an ARN such as
/// `arn:aws:iam::123456789012:role/name`.
///
/// The account is the fifth `:`-separated field. `None` is returned when the
/// input has fewer than five fields or when that field is empty (for example
/// S3 bucket ARNs like `arn:aws:s3:::bucket`). No other validation of the ARN
/// is performed.
fn extract_account_from_arn(arn: &str) -> Option<String> {
    // Walk the fields lazily; only the fifth one (index 4) matters.
    let account = arn.split(':').nth(4)?;
    if account.is_empty() {
        return None;
    }
    Some(account.to_owned())
}
806
807/// Extract the RoleSessionName from a base64-encoded SAML assertion.
808fn extract_saml_session_name(saml_b64: &str) -> Option<String> {
809    use base64::Engine;
810    let decoded = base64::engine::general_purpose::STANDARD
811        .decode(saml_b64)
812        .ok()?;
813    let xml_str = String::from_utf8(decoded).ok()?;
814
815    // Look for the RoleSessionName attribute value in the SAML XML.
816    let role_session_attr = "https://aws.amazon.com/SAML/Attributes/RoleSessionName";
817    let pos = xml_str.find(role_session_attr)?;
818
819    // Find the AttributeValue after this position
820    let after = &xml_str[pos..];
821    let av_start = after.find("AttributeValue")?;
822    let after_av = &after[av_start..];
823    // Skip past the closing >
824    let gt_pos = after_av.find('>')?;
825    let value_start = &after_av[gt_pos + 1..];
826    // Find end of value (next < which starts the closing tag)
827    let lt_pos = value_start.find('<')?;
828    let value = value_start[..lt_pos].trim();
829
830    if value.is_empty() {
831        None
832    } else {
833        Some(value.to_string())
834    }
835}
836
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::IamState;
    use parking_lot::RwLock;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Account ID that every test state and test request is created with.
    const TEST_ACCOUNT_ID: &str = "123456789012";

    /// Build an `StsService` over a fresh IAM state for [`TEST_ACCOUNT_ID`].
    fn make_service() -> StsService {
        let state: SharedIamState = Arc::new(RwLock::new(IamState::new(TEST_ACCOUNT_ID)));
        StsService::new(state)
    }

    /// Build a query-protocol STS request in `us-east-1` carrying `params`.
    fn make_test_request(params: HashMap<String, String>) -> AwsRequest {
        AwsRequest {
            service: "sts".into(),
            action: "Test".into(),
            region: "us-east-1".into(),
            account_id: TEST_ACCOUNT_ID.into(),
            request_id: "test".into(),
            headers: http::HeaderMap::new(),
            query_params: params,
            body: Default::default(),
            path_segments: vec![],
            raw_path: "/".into(),
            raw_query: String::new(),
            method: http::Method::POST,
            is_query_protocol: true,
            access_key_id: None,
            principal: None,
        }
    }

    /// Parse the `%Y-%m-%dT%H:%M:%SZ` timestamps produced by `format_expiration`.
    fn parse_expiration(s: &str) -> chrono::DateTime<Utc> {
        chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ")
            .expect("valid timestamp")
            .and_utc()
    }

    #[test]
    fn test_partition_for_region() {
        assert_eq!(partition_for_region("us-east-1"), "aws");
        assert_eq!(partition_for_region("eu-west-1"), "aws");
        assert_eq!(partition_for_region("cn-north-1"), "aws-cn");
        assert_eq!(partition_for_region("cn-northwest-1"), "aws-cn");
        assert_eq!(partition_for_region("us-isob-east-1"), "aws-iso-b");
        assert_eq!(partition_for_region("us-iso-east-1"), "aws-iso");
    }

    #[test]
    fn test_extract_account_from_arn() {
        assert_eq!(
            extract_account_from_arn("arn:aws:iam::123456789012:role/test"),
            Some("123456789012".to_string())
        );
        assert_eq!(
            extract_account_from_arn("arn:aws:iam::111111111111:role/test"),
            Some("111111111111".to_string())
        );
        assert_eq!(extract_account_from_arn("invalid"), None);
    }

    #[test]
    fn test_extract_account_from_arn_edge_cases() {
        // S3 bucket ARNs have an empty account field.
        assert_eq!(extract_account_from_arn("arn:aws:s3:::my-bucket"), None);
        // Fewer than five `:`-separated fields.
        assert_eq!(extract_account_from_arn("arn:aws:iam:"), None);
        assert_eq!(extract_account_from_arn(""), None);
        // Colons inside the resource part do not shift the account field.
        assert_eq!(
            extract_account_from_arn("arn:aws:logs:us-east-1:222222222222:log-group:foo:*"),
            Some("222222222222".to_string())
        );
    }

    #[test]
    fn test_extract_saml_session_name() {
        use base64::Engine;
        let xml = r#"<?xml version="1.0"?><samlp:Response><Assertion><AttributeStatement><Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><AttributeValue>testuser</AttributeValue></Attribute></AttributeStatement></Assertion></samlp:Response>"#;
        let encoded = base64::engine::general_purpose::STANDARD.encode(xml.as_bytes());
        assert_eq!(
            extract_saml_session_name(&encoded),
            Some("testuser".to_string())
        );
    }

    #[test]
    fn test_extract_saml_session_name_with_namespace() {
        use base64::Engine;
        let xml = r#"<?xml version="1.0"?><samlp:Response><saml:Assertion><saml:AttributeStatement><saml:Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><saml:AttributeValue>testuser</saml:AttributeValue></saml:Attribute></saml:AttributeStatement></saml:Assertion></samlp:Response>"#;
        let encoded = base64::engine::general_purpose::STANDARD.encode(xml.as_bytes());
        assert_eq!(
            extract_saml_session_name(&encoded),
            Some("testuser".to_string())
        );
    }

    #[test]
    fn test_session_token_format() {
        let token = xml_responses::generate_session_token();
        assert_eq!(token.len(), 356);
        assert!(token.starts_with("FQoGZXIvYXdzE"));
    }

    #[test]
    fn test_access_key_id_format() {
        let key = xml_responses::generate_access_key_id();
        assert_eq!(key.len(), 20);
        assert!(key.starts_with("FSIA"));
    }

    #[test]
    fn test_secret_access_key_format() {
        let key = xml_responses::generate_secret_access_key();
        assert_eq!(key.len(), 40);
    }

    #[test]
    fn test_role_id_format() {
        let id = xml_responses::generate_role_id();
        assert_eq!(id.len(), 21);
        assert!(id.starts_with("AROA"));
    }

    #[test]
    fn test_decode_authorization_message() {
        let service = make_service();

        let mut params = HashMap::new();
        params.insert(
            "EncodedMessage".to_string(),
            "some-encoded-message".to_string(),
        );

        let req = make_test_request(params);
        let resp = service.decode_authorization_message(&req).unwrap();
        let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
        assert!(body.contains("DecodedMessage"));
        assert!(body.contains("allowed"));
        assert!(body.contains("matchedStatements"));
    }

    #[test]
    fn test_decode_authorization_message_missing_param() {
        let service = make_service();

        let req = make_test_request(HashMap::new());
        let Err(err) = service.decode_authorization_message(&req) else {
            panic!("expected a MissingParameter error for absent EncodedMessage");
        };
        let msg = format!("{:?}", err);
        assert!(msg.contains("EncodedMessage"));
    }

    #[test]
    fn test_get_access_key_info_unknown_key_uses_default_account() {
        let service = make_service();

        let mut params = HashMap::new();
        params.insert(
            "AccessKeyId".to_string(),
            "AKIAIOSFODNN7EXAMPLE".to_string(),
        );

        let req = make_test_request(params);
        let resp = service.get_access_key_info(&req).unwrap();
        let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
        assert!(body.contains(TEST_ACCOUNT_ID));
    }

    #[test]
    fn test_get_access_key_info_missing_param() {
        let service = make_service();

        let req = make_test_request(HashMap::new());
        let Err(err) = service.get_access_key_info(&req) else {
            panic!("expected a MissingParameter error for absent AccessKeyId");
        };
        let msg = format!("{:?}", err);
        assert!(msg.contains("AccessKeyId"));
    }

    #[test]
    fn test_get_access_key_info_rejects_short_key() {
        let service = make_service();

        let mut params = HashMap::new();
        params.insert("AccessKeyId".to_string(), "short".to_string());

        let req = make_test_request(params);
        assert!(service.get_access_key_info(&req).is_err());
    }

    #[test]
    fn test_compute_expiration_with_duration() {
        let mut params = HashMap::new();
        params.insert("DurationSeconds".to_string(), "1800".to_string());
        let req = make_test_request(params);

        let now = Utc::now();
        let exp_str = compute_expiration(&req, 3600).unwrap();
        let exp_utc = parse_expiration(&exp_str);

        // Should be ~1800s from now (using provided DurationSeconds, not default)
        let diff = (exp_utc - now).num_seconds();
        assert!(
            (1798..=1802).contains(&diff),
            "expected ~1800s duration, got {diff}s"
        );
    }

    #[test]
    fn test_compute_expiration_default() {
        let req = make_test_request(HashMap::new());

        let now = Utc::now();
        let exp_str = compute_expiration(&req, 43200).unwrap();
        let exp_utc = parse_expiration(&exp_str);

        // Should be ~43200s (12 hours) from now using default
        let diff = (exp_utc - now).num_seconds();
        assert!(
            (43198..=43202).contains(&diff),
            "expected ~43200s duration, got {diff}s"
        );
    }

    #[test]
    fn test_compute_expiration_uses_provided_not_default() {
        let mut params = HashMap::new();
        params.insert("DurationSeconds".to_string(), "900".to_string());
        let req = make_test_request(params);

        let before = Utc::now();
        let exp_str = compute_expiration(&req, 43200).unwrap();
        let exp_utc = parse_expiration(&exp_str);

        // Should use 900s, not the default 43200s
        let expected = before + chrono::Duration::seconds(900);
        let diff = (exp_utc - expected).num_seconds().abs();
        assert!(
            diff <= 2,
            "expected ~900s duration, got diff={diff}s from expected"
        );
    }
}