Skip to main content

fakecloud_iam/
sts_service.rs

1use async_trait::async_trait;
2use chrono::Utc;
3use http::StatusCode;
4
5use fakecloud_core::service::{AwsRequest, AwsResponse, AwsService, AwsServiceError};
6use fakecloud_core::validation::*;
7
8use crate::state::{CredentialIdentity, SharedIamState};
9use crate::xml_responses::{self, StsCredentials};
10
/// Default credential lifetime in seconds for `AssumeRole`, `AssumeRoleWithWebIdentity`
/// and `AssumeRoleWithSAML` when no `DurationSeconds` is supplied (one hour).
const DEFAULT_ASSUME_ROLE_DURATION: i64 = 60 * 60;

/// Default credential lifetime in seconds for `GetSessionToken` (twelve hours).
const DEFAULT_SESSION_TOKEN_DURATION: i64 = 12 * 60 * 60;

/// Default credential lifetime in seconds for `GetFederationToken` (twelve hours).
const DEFAULT_FEDERATION_TOKEN_DURATION: i64 = 12 * 60 * 60;
19
20/// Compute an ISO 8601 expiration timestamp from an optional DurationSeconds parameter.
21fn compute_expiration(req: &AwsRequest, default_duration: i64) -> Result<String, AwsServiceError> {
22    let duration = if let Some(ds) = req.query_params.get("DurationSeconds") {
23        ds.parse::<i64>().map_err(|_| {
24            AwsServiceError::aws_error(
25                StatusCode::BAD_REQUEST,
26                "ValidationError",
27                format!(
28                    "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
29                     Member must be a valid integer",
30                    ds
31                ),
32            )
33        })?
34    } else {
35        default_duration
36    };
37    let expiration = Utc::now() + chrono::Duration::seconds(duration);
38    Ok(expiration.format("%Y-%m-%dT%H:%M:%SZ").to_string())
39}
40
41pub struct StsService {
42    state: SharedIamState,
43}
44
45impl StsService {
46    pub fn new(state: SharedIamState) -> Self {
47        Self { state }
48    }
49}
50
51#[async_trait]
52impl AwsService for StsService {
53    fn service_name(&self) -> &str {
54        "sts"
55    }
56
57    async fn handle(&self, req: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
58        match req.action.as_str() {
59            "GetCallerIdentity" => self.get_caller_identity(&req),
60            "AssumeRole" => self.assume_role(&req),
61            "AssumeRoleWithWebIdentity" => self.assume_role_with_web_identity(&req),
62            "AssumeRoleWithSAML" => self.assume_role_with_saml(&req),
63            "GetSessionToken" => self.get_session_token(&req),
64            "GetFederationToken" => self.get_federation_token(&req),
65            "GetAccessKeyInfo" => self.get_access_key_info(&req),
66            "DecodeAuthorizationMessage" => self.decode_authorization_message(&req),
67            _ => Err(AwsServiceError::action_not_implemented("sts", &req.action)),
68        }
69    }
70
71    fn supported_actions(&self) -> &[&str] {
72        &[
73            "GetCallerIdentity",
74            "AssumeRole",
75            "AssumeRoleWithWebIdentity",
76            "AssumeRoleWithSAML",
77            "GetSessionToken",
78            "GetFederationToken",
79            "GetAccessKeyInfo",
80            "DecodeAuthorizationMessage",
81        ]
82    }
83}
84
/// Maps a region name to its AWS partition (the second field of an ARN).
///
/// Regions are classified by prefix; anything unrecognised is treated as the
/// commercial `aws` partition. The result is a static string, so it does not borrow
/// from `region`.
fn partition_for_region(region: &str) -> &'static str {
    // Prefixes are mutually exclusive (e.g. "us-iso-" does not match "us-isob-east-1"),
    // so table order does not affect the result.
    const PARTITIONS: &[(&str, &str)] = &[
        ("cn-", "aws-cn"),
        ("us-gov-", "aws-us-gov"),
        ("us-iso-", "aws-iso"),
        ("us-isob-", "aws-iso-b"),
        ("us-isof-", "aws-iso-f"),
        ("eu-isoe-", "aws-iso-e"),
    ];

    PARTITIONS
        .iter()
        .find(|&&(prefix, _)| region.starts_with(prefix))
        .map_or("aws", |&(_, partition)| partition)
}
101
102/// Extract the caller's access key from the SigV4 Authorization header.
103fn extract_access_key(req: &AwsRequest) -> Option<String> {
104    let auth = req.headers.get("authorization")?.to_str().ok()?;
105    let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
106    Some(info.access_key)
107}
108
109impl StsService {
110    fn get_caller_identity(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
111        let state = self.state.read();
112        let partition = partition_for_region(&req.region);
113
114        // Check if caller has credentials that map to a known identity
115        if let Some(access_key) = extract_access_key(req) {
116            // First check credential_identities (assumed roles, etc.)
117            if let Some(identity) = state.credential_identities.get(&access_key) {
118                let xml = xml_responses::get_caller_identity_response(
119                    &identity.account_id,
120                    &identity.arn,
121                    &identity.user_id,
122                    &req.request_id,
123                );
124                return Ok(AwsResponse::xml(StatusCode::OK, xml));
125            }
126
127            // Then check IAM user access keys
128            for keys in state.access_keys.values() {
129                for key in keys {
130                    if key.access_key_id == access_key {
131                        if let Some(user) = state.users.get(&key.user_name) {
132                            let xml = xml_responses::get_caller_identity_response(
133                                &state.account_id,
134                                &user.arn,
135                                &user.user_id,
136                                &req.request_id,
137                            );
138                            return Ok(AwsResponse::xml(StatusCode::OK, xml));
139                        }
140                    }
141                }
142            }
143        }
144
145        // Default identity — matches real AWS root credentials
146        let arn = format!("arn:{}:iam::{}:root", partition, state.account_id);
147        let user_id = "FKIAIOSFODNN7EXAMPLE";
148        let xml = xml_responses::get_caller_identity_response(
149            &state.account_id,
150            &arn,
151            user_id,
152            &req.request_id,
153        );
154        Ok(AwsResponse::xml(StatusCode::OK, xml))
155    }
156
157    fn assume_role(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
158        let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
159            AwsServiceError::aws_error(
160                StatusCode::BAD_REQUEST,
161                "MissingParameter",
162                "The request must contain the parameter RoleArn",
163            )
164        })?;
165        validate_string_length("roleArn", role_arn, 20, 2048)?;
166
167        let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
168            AwsServiceError::aws_error(
169                StatusCode::BAD_REQUEST,
170                "MissingParameter",
171                "The request must contain the parameter RoleSessionName",
172            )
173        })?;
174        validate_string_length("roleSessionName", role_session_name, 2, 64)?;
175
176        // Validate optional DurationSeconds (used below for expiration)
177        if let Some(ds) = req.query_params.get("DurationSeconds") {
178            let v = ds.parse::<i64>().map_err(|_| {
179                AwsServiceError::aws_error(
180                    StatusCode::BAD_REQUEST,
181                    "ValidationError",
182                    format!(
183                        "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
184                         Member must be a valid integer",
185                        ds
186                    ),
187                )
188            })?;
189            validate_range_i64("durationSeconds", v, 900, 43200)?;
190        }
191
192        // Validate optional ExternalId
193        validate_optional_string_length(
194            "externalId",
195            req.query_params.get("ExternalId").map(|s| s.as_str()),
196            2,
197            1224,
198        )?;
199
200        // Validate optional Policy
201        validate_optional_string_length(
202            "policy",
203            req.query_params.get("Policy").map(|s| s.as_str()),
204            1,
205            2048,
206        )?;
207
208        // Validate optional SourceIdentity
209        validate_optional_string_length(
210            "sourceIdentity",
211            req.query_params.get("SourceIdentity").map(|s| s.as_str()),
212            2,
213            64,
214        )?;
215
216        // Validate and accept optional MFA SerialNumber
217        validate_optional_string_length(
218            "serialNumber",
219            req.query_params.get("SerialNumber").map(|s| s.as_str()),
220            9,
221            256,
222        )?;
223        let serial_number = req.query_params.get("SerialNumber").cloned();
224
225        // Validate and accept optional MFA TokenCode
226        validate_optional_string_length(
227            "tokenCode",
228            req.query_params.get("TokenCode").map(|s| s.as_str()),
229            6,
230            6,
231        )?;
232        let token_code = req.query_params.get("TokenCode").cloned();
233
234        // Compute expiration from DurationSeconds (default 3600s)
235        let expiration = compute_expiration(req, DEFAULT_ASSUME_ROLE_DURATION)?;
236
237        // Accept MFA parameters without verification (emulator behavior)
238        let _mfa_serial = serial_number;
239        let _mfa_token = token_code;
240
241        let partition = partition_for_region(&req.region);
242        let creds = StsCredentials::generate();
243
244        let mut state = self.state.write();
245
246        // Extract account ID from role ARN if present, otherwise use default
247        let account_id =
248            extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
249
250        // Try to find the role in state to get its role_id
251        let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
252        let role_id = state
253            .roles
254            .get(role_name)
255            .map(|r| r.role_id.clone())
256            .unwrap_or_else(xml_responses::generate_role_id);
257
258        let assumed_role_arn = format!(
259            "arn:{}:sts::{}:assumed-role/{}/{}",
260            partition, account_id, role_name, role_session_name
261        );
262        let assumed_role_id = format!("{}:{}", role_id, role_session_name);
263
264        // Store credential identity for GetCallerIdentity lookups
265        state.credential_identities.insert(
266            creds.access_key_id.clone(),
267            CredentialIdentity {
268                arn: assumed_role_arn,
269                user_id: assumed_role_id,
270                account_id: account_id.clone(),
271            },
272        );
273
274        let xml = xml_responses::assume_role_response(
275            role_arn,
276            role_session_name,
277            &role_id,
278            &account_id,
279            partition,
280            &creds,
281            &expiration,
282            &req.request_id,
283        );
284        Ok(AwsResponse::xml(StatusCode::OK, xml))
285    }
286
287    fn assume_role_with_web_identity(
288        &self,
289        req: &AwsRequest,
290    ) -> Result<AwsResponse, AwsServiceError> {
291        let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
292            AwsServiceError::aws_error(
293                StatusCode::BAD_REQUEST,
294                "MissingParameter",
295                "The request must contain the parameter RoleArn",
296            )
297        })?;
298        validate_string_length("roleArn", role_arn, 20, 2048)?;
299
300        let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
301            AwsServiceError::aws_error(
302                StatusCode::BAD_REQUEST,
303                "MissingParameter",
304                "The request must contain the parameter RoleSessionName",
305            )
306        })?;
307        validate_string_length("roleSessionName", role_session_name, 2, 64)?;
308
309        // WebIdentityToken is required
310        let web_identity_token = req.query_params.get("WebIdentityToken").ok_or_else(|| {
311            AwsServiceError::aws_error(
312                StatusCode::BAD_REQUEST,
313                "MissingParameter",
314                "The request must contain the parameter WebIdentityToken",
315            )
316        })?;
317        validate_string_length("webIdentityToken", web_identity_token, 4, 20000)?;
318        let _web_identity_token = web_identity_token.clone();
319
320        // Validate optional Policy
321        validate_optional_string_length(
322            "policy",
323            req.query_params.get("Policy").map(|s| s.as_str()),
324            1,
325            2048,
326        )?;
327
328        // Validate optional ProviderId
329        validate_optional_string_length(
330            "providerId",
331            req.query_params.get("ProviderId").map(|s| s.as_str()),
332            4,
333            2048,
334        )?;
335
336        // Validate optional DurationSeconds (used below for expiration)
337        if let Some(ds) = req.query_params.get("DurationSeconds") {
338            let v = ds.parse::<i64>().map_err(|_| {
339                AwsServiceError::aws_error(
340                    StatusCode::BAD_REQUEST,
341                    "ValidationError",
342                    format!(
343                        "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
344                         Member must be a valid integer",
345                        ds
346                    ),
347                )
348            })?;
349            validate_range_i64("durationSeconds", v, 900, 43200)?;
350        }
351
352        // Compute expiration from DurationSeconds (default 3600s)
353        let expiration = compute_expiration(req, DEFAULT_ASSUME_ROLE_DURATION)?;
354
355        let partition = partition_for_region(&req.region);
356        let creds = StsCredentials::generate();
357        let role_id = xml_responses::generate_role_id();
358
359        let mut state = self.state.write();
360        let account_id =
361            extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
362
363        let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
364        let assumed_role_arn = format!(
365            "arn:{}:sts::{}:assumed-role/{}/{}",
366            partition, account_id, role_name, role_session_name
367        );
368        let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
369
370        state.credential_identities.insert(
371            creds.access_key_id.clone(),
372            CredentialIdentity {
373                arn: assumed_role_arn,
374                user_id: assumed_role_id_str,
375                account_id: account_id.clone(),
376            },
377        );
378
379        let xml = xml_responses::assume_role_with_web_identity_response(
380            role_arn,
381            role_session_name,
382            &account_id,
383            partition,
384            &creds,
385            &role_id,
386            &expiration,
387            &req.request_id,
388        );
389        Ok(AwsResponse::xml(StatusCode::OK, xml))
390    }
391
392    fn assume_role_with_saml(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
393        let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
394            AwsServiceError::aws_error(
395                StatusCode::BAD_REQUEST,
396                "MissingParameter",
397                "The request must contain the parameter RoleArn",
398            )
399        })?;
400        validate_string_length("roleArn", role_arn, 20, 2048)?;
401
402        // PrincipalArn is required
403        let principal_arn = req.query_params.get("PrincipalArn").ok_or_else(|| {
404            AwsServiceError::aws_error(
405                StatusCode::BAD_REQUEST,
406                "MissingParameter",
407                "The request must contain the parameter PrincipalArn",
408            )
409        })?;
410        validate_string_length("principalArn", principal_arn, 20, 2048)?;
411        let _principal_arn = principal_arn.clone();
412
413        // SAMLAssertion is required but we just need to extract session name from it
414        let saml_assertion = req.query_params.get("SAMLAssertion").ok_or_else(|| {
415            AwsServiceError::aws_error(
416                StatusCode::BAD_REQUEST,
417                "MissingParameter",
418                "The request must contain the parameter SAMLAssertion",
419            )
420        })?;
421        validate_string_length("sAMLAssertion", saml_assertion, 4, 100000)?;
422
423        // Validate optional Policy
424        validate_optional_string_length(
425            "policy",
426            req.query_params.get("Policy").map(|s| s.as_str()),
427            1,
428            2048,
429        )?;
430
431        // Validate optional DurationSeconds (used below for expiration)
432        if let Some(ds) = req.query_params.get("DurationSeconds") {
433            let v = ds.parse::<i64>().map_err(|_| {
434                AwsServiceError::aws_error(
435                    StatusCode::BAD_REQUEST,
436                    "ValidationError",
437                    format!(
438                        "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
439                         Member must be a valid integer",
440                        ds
441                    ),
442                )
443            })?;
444            validate_range_i64("durationSeconds", v, 900, 43200)?;
445        }
446
447        // Compute expiration from DurationSeconds (default 3600s)
448        let expiration = compute_expiration(req, DEFAULT_ASSUME_ROLE_DURATION)?;
449
450        // Decode the SAML assertion to extract the RoleSessionName
451        let role_session_name =
452            extract_saml_session_name(saml_assertion).unwrap_or_else(|| "saml-session".to_string());
453
454        let partition = partition_for_region(&req.region);
455        let creds = StsCredentials::generate();
456        let role_id = xml_responses::generate_role_id();
457
458        let mut state = self.state.write();
459        let account_id =
460            extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
461
462        let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
463        let assumed_role_arn = format!(
464            "arn:{}:sts::{}:assumed-role/{}/{}",
465            partition, account_id, role_name, &role_session_name
466        );
467        let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
468
469        state.credential_identities.insert(
470            creds.access_key_id.clone(),
471            CredentialIdentity {
472                arn: assumed_role_arn,
473                user_id: assumed_role_id_str,
474                account_id: account_id.clone(),
475            },
476        );
477
478        let xml = xml_responses::assume_role_with_saml_response(
479            role_arn,
480            &role_session_name,
481            &account_id,
482            partition,
483            &creds,
484            &role_id,
485            &expiration,
486            &req.request_id,
487        );
488        Ok(AwsResponse::xml(StatusCode::OK, xml))
489    }
490
491    fn get_session_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
492        // Validate optional DurationSeconds (used below for expiration)
493        if let Some(ds) = req.query_params.get("DurationSeconds") {
494            let v = ds.parse::<i64>().map_err(|_| {
495                AwsServiceError::aws_error(
496                    StatusCode::BAD_REQUEST,
497                    "ValidationError",
498                    format!(
499                        "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
500                         Member must be a valid integer",
501                        ds
502                    ),
503                )
504            })?;
505            validate_range_i64("durationSeconds", v, 900, 129600)?;
506        }
507
508        // Validate and accept optional MFA SerialNumber (no verification in emulator)
509        validate_optional_string_length(
510            "serialNumber",
511            req.query_params.get("SerialNumber").map(|s| s.as_str()),
512            9,
513            256,
514        )?;
515        let _serial_number = req.query_params.get("SerialNumber").cloned();
516
517        // Validate and accept optional MFA TokenCode (no verification in emulator)
518        validate_optional_string_length(
519            "tokenCode",
520            req.query_params.get("TokenCode").map(|s| s.as_str()),
521            6,
522            6,
523        )?;
524        let _token_code = req.query_params.get("TokenCode").cloned();
525
526        // Compute expiration from DurationSeconds (default 43200s / 12 hours)
527        let expiration = compute_expiration(req, DEFAULT_SESSION_TOKEN_DURATION)?;
528
529        let xml = xml_responses::get_session_token_response(&expiration, &req.request_id);
530        Ok(AwsResponse::xml(StatusCode::OK, xml))
531    }
532
533    fn get_federation_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
534        let name = req.query_params.get("Name").ok_or_else(|| {
535            AwsServiceError::aws_error(
536                StatusCode::BAD_REQUEST,
537                "MissingParameter",
538                "The request must contain the parameter Name",
539            )
540        })?;
541        validate_string_length("name", name, 2, 32)?;
542
543        // Validate optional DurationSeconds (used below for expiration)
544        if let Some(ds) = req.query_params.get("DurationSeconds") {
545            let v = ds.parse::<i64>().map_err(|_| {
546                AwsServiceError::aws_error(
547                    StatusCode::BAD_REQUEST,
548                    "ValidationError",
549                    format!(
550                        "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
551                         Member must be a valid integer",
552                        ds
553                    ),
554                )
555            })?;
556            validate_range_i64("durationSeconds", v, 900, 129600)?;
557        }
558
559        // Validate and store optional policy
560        validate_optional_string_length(
561            "policy",
562            req.query_params.get("Policy").map(|s| s.as_str()),
563            1,
564            2048,
565        )?;
566        let policy = req.query_params.get("Policy").cloned();
567
568        // Compute expiration from DurationSeconds (default 43200s / 12 hours)
569        let expiration = compute_expiration(req, DEFAULT_FEDERATION_TOKEN_DURATION)?;
570
571        let partition = partition_for_region(&req.region);
572        let state = self.state.read();
573        let xml = xml_responses::get_federation_token_response(
574            name,
575            &state.account_id,
576            partition,
577            &expiration,
578            policy.as_deref(),
579            &req.request_id,
580        );
581        Ok(AwsResponse::xml(StatusCode::OK, xml))
582    }
583
584    fn decode_authorization_message(
585        &self,
586        req: &AwsRequest,
587    ) -> Result<AwsResponse, AwsServiceError> {
588        let _encoded_message = req.query_params.get("EncodedMessage").ok_or_else(|| {
589            AwsServiceError::aws_error(
590                StatusCode::BAD_REQUEST,
591                "MissingParameter",
592                "The request must contain the parameter EncodedMessage",
593            )
594        })?;
595
596        let decoded_message =
597            r#"{"allowed":true,"explicitDeny":false,"matchedStatements":{"items":[]}}"#;
598        let xml =
599            xml_responses::decode_authorization_message_response(decoded_message, &req.request_id);
600        Ok(AwsResponse::xml(StatusCode::OK, xml))
601    }
602
603    fn get_access_key_info(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
604        let access_key_id = req.query_params.get("AccessKeyId").ok_or_else(|| {
605            AwsServiceError::aws_error(
606                StatusCode::BAD_REQUEST,
607                "MissingParameter",
608                "The request must contain the parameter AccessKeyId",
609            )
610        })?;
611        validate_string_length("accessKeyId", access_key_id, 16, 128)?;
612
613        // Try to resolve account from known access keys, fall back to default
614        let state = self.state.read();
615        let account_id = state
616            .access_keys
617            .values()
618            .flatten()
619            .find(|k| k.access_key_id == *access_key_id)
620            .map(|_| state.account_id.clone())
621            .or_else(|| {
622                state
623                    .credential_identities
624                    .get(access_key_id.as_str())
625                    .map(|ci| ci.account_id.clone())
626            })
627            .unwrap_or_else(|| state.account_id.clone());
628
629        let xml = xml_responses::get_access_key_info_response(&account_id, &req.request_id);
630        Ok(AwsResponse::xml(StatusCode::OK, xml))
631    }
632}
633
/// Returns the account-ID field of an ARN such as `arn:aws:iam::123456789012:role/name`.
///
/// The account ID is the fifth `:`-separated component. Returns `None` when that
/// component is missing or empty (e.g. S3 bucket ARNs).
fn extract_account_from_arn(arn: &str) -> Option<String> {
    arn.split(':')
        .nth(4)
        .filter(|account| !account.is_empty())
        .map(str::to_owned)
}
643
644/// Extract the RoleSessionName from a base64-encoded SAML assertion.
645fn extract_saml_session_name(saml_b64: &str) -> Option<String> {
646    use base64::Engine;
647    let decoded = base64::engine::general_purpose::STANDARD
648        .decode(saml_b64)
649        .ok()?;
650    let xml_str = String::from_utf8(decoded).ok()?;
651
652    // Look for the RoleSessionName attribute value in the SAML XML.
653    let role_session_attr = "https://aws.amazon.com/SAML/Attributes/RoleSessionName";
654    let pos = xml_str.find(role_session_attr)?;
655
656    // Find the AttributeValue after this position
657    let after = &xml_str[pos..];
658    let av_start = after.find("AttributeValue")?;
659    let after_av = &after[av_start..];
660    // Skip past the closing >
661    let gt_pos = after_av.find('>')?;
662    let value_start = &after_av[gt_pos + 1..];
663    // Find end of value (next < which starts the closing tag)
664    let lt_pos = value_start.find('<')?;
665    let value = value_start[..lt_pos].trim();
666
667    if value.is_empty() {
668        None
669    } else {
670        Some(value.to_string())
671    }
672}
673
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a POST query-protocol STS request (us-east-1, account 123456789012)
    /// carrying `params` as its query parameters.
    fn make_test_request(params: HashMap<String, String>) -> AwsRequest {
        AwsRequest {
            service: "sts".into(),
            action: "Test".into(),
            region: "us-east-1".into(),
            account_id: "123456789012".into(),
            request_id: "test".into(),
            headers: http::HeaderMap::new(),
            query_params: params,
            body: Default::default(),
            path_segments: vec![],
            raw_path: "/".into(),
            method: http::Method::POST,
            is_query_protocol: true,
            access_key_id: None,
        }
    }

    /// Creates an `StsService` over a fresh IAM state for account 123456789012.
    fn fresh_service() -> StsService {
        use crate::state::IamState;
        use parking_lot::RwLock;
        use std::sync::Arc;

        let shared: SharedIamState = Arc::new(RwLock::new(IamState::new("123456789012")));
        StsService::new(shared)
    }

    /// Base64-encodes a SAML document the way it arrives in `SAMLAssertion`.
    fn encode_saml(xml: &str) -> String {
        use base64::Engine;
        base64::engine::general_purpose::STANDARD.encode(xml.as_bytes())
    }

    /// Parses the `%Y-%m-%dT%H:%M:%SZ` timestamps produced by `compute_expiration`.
    fn parse_expiration(s: &str) -> chrono::DateTime<Utc> {
        chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ")
            .expect("valid timestamp")
            .and_utc()
    }

    #[test]
    fn test_partition_for_region() {
        let cases = [
            ("us-east-1", "aws"),
            ("eu-west-1", "aws"),
            ("cn-north-1", "aws-cn"),
            ("cn-northwest-1", "aws-cn"),
            ("us-isob-east-1", "aws-iso-b"),
            ("us-iso-east-1", "aws-iso"),
        ];
        for &(region, expected) in cases.iter() {
            assert_eq!(partition_for_region(region), expected);
        }
    }

    #[test]
    fn test_extract_account_from_arn() {
        for account in ["123456789012", "111111111111"].iter() {
            let arn = format!("arn:aws:iam::{}:role/test", account);
            assert_eq!(extract_account_from_arn(&arn), Some(account.to_string()));
        }
        assert_eq!(extract_account_from_arn("invalid"), None);
    }

    #[test]
    fn test_extract_saml_session_name() {
        let xml = r#"<?xml version="1.0"?><samlp:Response><Assertion><AttributeStatement><Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><AttributeValue>testuser</AttributeValue></Attribute></AttributeStatement></Assertion></samlp:Response>"#;
        assert_eq!(
            extract_saml_session_name(&encode_saml(xml)),
            Some("testuser".to_string())
        );
    }

    #[test]
    fn test_extract_saml_session_name_with_namespace() {
        let xml = r#"<?xml version="1.0"?><samlp:Response><saml:Assertion><saml:AttributeStatement><saml:Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><saml:AttributeValue>testuser</saml:AttributeValue></saml:Attribute></saml:AttributeStatement></saml:Assertion></samlp:Response>"#;
        assert_eq!(
            extract_saml_session_name(&encode_saml(xml)),
            Some("testuser".to_string())
        );
    }

    #[test]
    fn test_session_token_format() {
        let session_token = xml_responses::generate_session_token();
        assert_eq!(session_token.len(), 356);
        assert!(session_token.starts_with("FQoGZXIvYXdzE"));
    }

    #[test]
    fn test_access_key_id_format() {
        let access_key = xml_responses::generate_access_key_id();
        assert_eq!(access_key.len(), 20);
        assert!(access_key.starts_with("FSIA"));
    }

    #[test]
    fn test_secret_access_key_format() {
        let secret = xml_responses::generate_secret_access_key();
        assert_eq!(secret.len(), 40);
    }

    #[test]
    fn test_role_id_format() {
        let role_id = xml_responses::generate_role_id();
        assert_eq!(role_id.len(), 21);
        assert!(role_id.starts_with("AROA"));
    }

    #[test]
    fn test_decode_authorization_message() {
        let service = fresh_service();

        let mut query = HashMap::new();
        query.insert(
            "EncodedMessage".to_string(),
            "some-encoded-message".to_string(),
        );

        let response = service
            .decode_authorization_message(&make_test_request(query))
            .unwrap();
        let text = std::str::from_utf8(&response.body).unwrap();
        for needle in ["DecodedMessage", "allowed", "matchedStatements"].iter() {
            assert!(text.contains(needle));
        }
    }

    #[test]
    fn test_decode_authorization_message_missing_param() {
        let service = fresh_service();

        let outcome = service.decode_authorization_message(&make_test_request(HashMap::new()));
        assert!(outcome.is_err());
        let rendered = format!("{:?}", outcome.err().unwrap());
        assert!(rendered.contains("EncodedMessage"));
    }

    #[test]
    fn test_compute_expiration_with_duration() {
        let mut query = HashMap::new();
        query.insert("DurationSeconds".to_string(), "1800".to_string());
        let req = make_test_request(query);

        let now = Utc::now();
        let expires_at = parse_expiration(&compute_expiration(&req, 3600).unwrap());

        // Should be ~1800s from now (using provided DurationSeconds, not default)
        let diff = (expires_at - now).num_seconds();
        assert!(
            (1798..=1802).contains(&diff),
            "expected ~1800s duration, got {diff}s"
        );
    }

    #[test]
    fn test_compute_expiration_default() {
        let req = make_test_request(HashMap::new());

        let now = Utc::now();
        let expires_at = parse_expiration(&compute_expiration(&req, 43200).unwrap());

        // Should be ~43200s (12 hours) from now using default
        let diff = (expires_at - now).num_seconds();
        assert!(
            (43198..=43202).contains(&diff),
            "expected ~43200s duration, got {diff}s"
        );
    }

    #[test]
    fn test_compute_expiration_uses_provided_not_default() {
        let mut query = HashMap::new();
        query.insert("DurationSeconds".to_string(), "900".to_string());
        let req = make_test_request(query);

        let before = Utc::now();
        let expires_at = parse_expiration(&compute_expiration(&req, 43200).unwrap());

        // Should use 900s, not the default 43200s
        let expected = before + chrono::Duration::seconds(900);
        let diff = (expires_at - expected).num_seconds().abs();
        assert!(
            diff <= 2,
            "expected ~900s duration, got diff={diff}s from expected"
        );
    }
}