1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use http::StatusCode;
4
5use fakecloud_core::service::{AwsRequest, AwsResponse, AwsService, AwsServiceError};
6use fakecloud_core::validation::*;
7
8use crate::state::{CredentialIdentity, SharedIamState, StsTempCredential};
9use crate::xml_responses::{self, StsCredentials};
10
11const DEFAULT_ASSUME_ROLE_DURATION: i64 = 3600;
13
14const DEFAULT_SESSION_TOKEN_DURATION: i64 = 43200;
16
17const DEFAULT_FEDERATION_TOKEN_DURATION: i64 = 43200;
19
20fn compute_expiration_at(
22 req: &AwsRequest,
23 default_duration: i64,
24) -> Result<DateTime<Utc>, AwsServiceError> {
25 let duration = if let Some(ds) = req.query_params.get("DurationSeconds") {
26 ds.parse::<i64>().map_err(|_| {
27 AwsServiceError::aws_error(
28 StatusCode::BAD_REQUEST,
29 "ValidationError",
30 format!(
31 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
32 Member must be a valid integer",
33 ds
34 ),
35 )
36 })?
37 } else {
38 default_duration
39 };
40 Ok(Utc::now() + chrono::Duration::seconds(duration))
41}
42
43fn format_expiration(ts: DateTime<Utc>) -> String {
45 ts.format("%Y-%m-%dT%H:%M:%SZ").to_string()
46}
47
48#[cfg(test)]
51fn compute_expiration(req: &AwsRequest, default_duration: i64) -> Result<String, AwsServiceError> {
52 Ok(format_expiration(compute_expiration_at(
53 req,
54 default_duration,
55 )?))
56}
57
58pub struct StsService {
59 state: SharedIamState,
60}
61
62impl StsService {
63 pub fn new(state: SharedIamState) -> Self {
64 Self { state }
65 }
66}
67
68#[async_trait]
69impl AwsService for StsService {
70 fn service_name(&self) -> &str {
71 "sts"
72 }
73
74 async fn handle(&self, req: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
75 match req.action.as_str() {
76 "GetCallerIdentity" => self.get_caller_identity(&req),
77 "AssumeRole" => self.assume_role(&req),
78 "AssumeRoleWithWebIdentity" => self.assume_role_with_web_identity(&req),
79 "AssumeRoleWithSAML" => self.assume_role_with_saml(&req),
80 "GetSessionToken" => self.get_session_token(&req),
81 "GetFederationToken" => self.get_federation_token(&req),
82 "GetAccessKeyInfo" => self.get_access_key_info(&req),
83 "DecodeAuthorizationMessage" => self.decode_authorization_message(&req),
84 _ => Err(AwsServiceError::action_not_implemented("sts", &req.action)),
85 }
86 }
87
88 fn supported_actions(&self) -> &[&str] {
89 &[
90 "GetCallerIdentity",
91 "AssumeRole",
92 "AssumeRoleWithWebIdentity",
93 "AssumeRoleWithSAML",
94 "GetSessionToken",
95 "GetFederationToken",
96 "GetAccessKeyInfo",
97 "DecodeAuthorizationMessage",
98 ]
99 }
100
101 fn iam_enforceable(&self) -> bool {
103 true
104 }
105
106 fn iam_action_for(
111 &self,
112 request: &fakecloud_core::service::AwsRequest,
113 ) -> Option<fakecloud_core::auth::IamAction> {
114 let action: &'static str = match request.action.as_str() {
115 "GetCallerIdentity" => "GetCallerIdentity",
116 "AssumeRole" => "AssumeRole",
117 "AssumeRoleWithWebIdentity" => "AssumeRoleWithWebIdentity",
118 "AssumeRoleWithSAML" => "AssumeRoleWithSAML",
119 "GetSessionToken" => "GetSessionToken",
120 "GetFederationToken" => "GetFederationToken",
121 "GetAccessKeyInfo" => "GetAccessKeyInfo",
122 "DecodeAuthorizationMessage" => "DecodeAuthorizationMessage",
123 _ => return None,
124 };
125 let resource = match action {
126 "AssumeRole" | "AssumeRoleWithWebIdentity" | "AssumeRoleWithSAML" => request
127 .query_params
128 .get("RoleArn")
129 .cloned()
130 .unwrap_or_else(|| "*".to_string()),
131 _ => "*".to_string(),
132 };
133 Some(fakecloud_core::auth::IamAction {
134 service: "sts",
135 action,
136 resource,
137 })
138 }
139}
140
141fn partition_for_region(region: &str) -> &str {
143 if region.starts_with("cn-") {
144 "aws-cn"
145 } else if region.starts_with("us-iso-") {
146 "aws-iso"
147 } else if region.starts_with("us-isob-") {
148 "aws-iso-b"
149 } else if region.starts_with("us-isof-") {
150 "aws-iso-f"
151 } else if region.starts_with("eu-isoe-") {
152 "aws-iso-e"
153 } else {
154 "aws"
155 }
156}
157
158fn extract_access_key(req: &AwsRequest) -> Option<String> {
160 let auth = req.headers.get("authorization")?.to_str().ok()?;
161 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
162 Some(info.access_key)
163}
164
165impl StsService {
166 fn get_caller_identity(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
167 if let Some(principal) = req.principal.as_ref() {
173 let xml = xml_responses::get_caller_identity_response(
174 &principal.account_id,
175 &principal.arn,
176 &principal.user_id,
177 &req.request_id,
178 );
179 return Ok(AwsResponse::xml(StatusCode::OK, xml));
180 }
181
182 let state = self.state.read();
183 let partition = partition_for_region(&req.region);
184 let arn = format!("arn:{}:iam::{}:root", partition, state.account_id);
185 let user_id = "FKIAIOSFODNN7EXAMPLE";
186 let xml = xml_responses::get_caller_identity_response(
187 &state.account_id,
188 &arn,
189 user_id,
190 &req.request_id,
191 );
192 Ok(AwsResponse::xml(StatusCode::OK, xml))
193 }
194
195 fn assume_role(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
196 let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
197 AwsServiceError::aws_error(
198 StatusCode::BAD_REQUEST,
199 "MissingParameter",
200 "The request must contain the parameter RoleArn",
201 )
202 })?;
203 validate_string_length("roleArn", role_arn, 20, 2048)?;
204
205 let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
206 AwsServiceError::aws_error(
207 StatusCode::BAD_REQUEST,
208 "MissingParameter",
209 "The request must contain the parameter RoleSessionName",
210 )
211 })?;
212 validate_string_length("roleSessionName", role_session_name, 2, 64)?;
213
214 if let Some(ds) = req.query_params.get("DurationSeconds") {
216 let v = ds.parse::<i64>().map_err(|_| {
217 AwsServiceError::aws_error(
218 StatusCode::BAD_REQUEST,
219 "ValidationError",
220 format!(
221 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
222 Member must be a valid integer",
223 ds
224 ),
225 )
226 })?;
227 validate_range_i64("durationSeconds", v, 900, 43200)?;
228 }
229
230 validate_optional_string_length(
232 "externalId",
233 req.query_params.get("ExternalId").map(|s| s.as_str()),
234 2,
235 1224,
236 )?;
237
238 validate_optional_string_length(
240 "policy",
241 req.query_params.get("Policy").map(|s| s.as_str()),
242 1,
243 2048,
244 )?;
245
246 validate_optional_string_length(
248 "sourceIdentity",
249 req.query_params.get("SourceIdentity").map(|s| s.as_str()),
250 2,
251 64,
252 )?;
253
254 validate_optional_string_length(
256 "serialNumber",
257 req.query_params.get("SerialNumber").map(|s| s.as_str()),
258 9,
259 256,
260 )?;
261 let serial_number = req.query_params.get("SerialNumber").cloned();
262
263 validate_optional_string_length(
265 "tokenCode",
266 req.query_params.get("TokenCode").map(|s| s.as_str()),
267 6,
268 6,
269 )?;
270 let token_code = req.query_params.get("TokenCode").cloned();
271
272 let expiration_at = compute_expiration_at(req, DEFAULT_ASSUME_ROLE_DURATION)?;
274 let expiration = format_expiration(expiration_at);
275
276 let _mfa_serial = serial_number;
278 let _mfa_token = token_code;
279
280 let partition = partition_for_region(&req.region);
281 let creds = StsCredentials::generate();
282
283 let mut state = self.state.write();
284
285 let account_id =
287 extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
288
289 let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
291 let role_id = state
292 .roles
293 .get(role_name)
294 .map(|r| r.role_id.clone())
295 .unwrap_or_else(xml_responses::generate_role_id);
296
297 let assumed_role_arn = format!(
298 "arn:{}:sts::{}:assumed-role/{}/{}",
299 partition, account_id, role_name, role_session_name
300 );
301 let assumed_role_id = format!("{}:{}", role_id, role_session_name);
302
303 state.credential_identities.insert(
305 creds.access_key_id.clone(),
306 CredentialIdentity {
307 arn: assumed_role_arn.clone(),
308 user_id: assumed_role_id.clone(),
309 account_id: account_id.clone(),
310 },
311 );
312 state.sts_temp_credentials.insert(
315 creds.access_key_id.clone(),
316 StsTempCredential {
317 access_key_id: creds.access_key_id.clone(),
318 secret_access_key: creds.secret_access_key.clone(),
319 session_token: creds.session_token.clone(),
320 principal_arn: assumed_role_arn,
321 user_id: assumed_role_id,
322 account_id: account_id.clone(),
323 expiration: expiration_at,
324 },
325 );
326
327 let xml = xml_responses::assume_role_response(&xml_responses::AssumedRoleInfo {
328 role_arn,
329 role_session_name,
330 assumed_role_id: &role_id,
331 account_id: &account_id,
332 partition,
333 creds: &creds,
334 expiration: &expiration,
335 request_id: &req.request_id,
336 });
337 Ok(AwsResponse::xml(StatusCode::OK, xml))
338 }
339
340 fn assume_role_with_web_identity(
341 &self,
342 req: &AwsRequest,
343 ) -> Result<AwsResponse, AwsServiceError> {
344 let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
345 AwsServiceError::aws_error(
346 StatusCode::BAD_REQUEST,
347 "MissingParameter",
348 "The request must contain the parameter RoleArn",
349 )
350 })?;
351 validate_string_length("roleArn", role_arn, 20, 2048)?;
352
353 let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
354 AwsServiceError::aws_error(
355 StatusCode::BAD_REQUEST,
356 "MissingParameter",
357 "The request must contain the parameter RoleSessionName",
358 )
359 })?;
360 validate_string_length("roleSessionName", role_session_name, 2, 64)?;
361
362 let web_identity_token = req.query_params.get("WebIdentityToken").ok_or_else(|| {
364 AwsServiceError::aws_error(
365 StatusCode::BAD_REQUEST,
366 "MissingParameter",
367 "The request must contain the parameter WebIdentityToken",
368 )
369 })?;
370 validate_string_length("webIdentityToken", web_identity_token, 4, 20000)?;
371 let _web_identity_token = web_identity_token.clone();
372
373 validate_optional_string_length(
375 "policy",
376 req.query_params.get("Policy").map(|s| s.as_str()),
377 1,
378 2048,
379 )?;
380
381 validate_optional_string_length(
383 "providerId",
384 req.query_params.get("ProviderId").map(|s| s.as_str()),
385 4,
386 2048,
387 )?;
388
389 if let Some(ds) = req.query_params.get("DurationSeconds") {
391 let v = ds.parse::<i64>().map_err(|_| {
392 AwsServiceError::aws_error(
393 StatusCode::BAD_REQUEST,
394 "ValidationError",
395 format!(
396 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
397 Member must be a valid integer",
398 ds
399 ),
400 )
401 })?;
402 validate_range_i64("durationSeconds", v, 900, 43200)?;
403 }
404
405 let expiration_at = compute_expiration_at(req, DEFAULT_ASSUME_ROLE_DURATION)?;
407 let expiration = format_expiration(expiration_at);
408
409 let partition = partition_for_region(&req.region);
410 let creds = StsCredentials::generate();
411 let role_id = xml_responses::generate_role_id();
412
413 let mut state = self.state.write();
414 let account_id =
415 extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
416
417 let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
418 let assumed_role_arn = format!(
419 "arn:{}:sts::{}:assumed-role/{}/{}",
420 partition, account_id, role_name, role_session_name
421 );
422 let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
423
424 state.credential_identities.insert(
425 creds.access_key_id.clone(),
426 CredentialIdentity {
427 arn: assumed_role_arn.clone(),
428 user_id: assumed_role_id_str.clone(),
429 account_id: account_id.clone(),
430 },
431 );
432 state.sts_temp_credentials.insert(
433 creds.access_key_id.clone(),
434 StsTempCredential {
435 access_key_id: creds.access_key_id.clone(),
436 secret_access_key: creds.secret_access_key.clone(),
437 session_token: creds.session_token.clone(),
438 principal_arn: assumed_role_arn,
439 user_id: assumed_role_id_str,
440 account_id: account_id.clone(),
441 expiration: expiration_at,
442 },
443 );
444
445 let xml = xml_responses::assume_role_with_web_identity_response(
446 &xml_responses::AssumedRoleInfo {
447 role_arn,
448 role_session_name,
449 assumed_role_id: &role_id,
450 account_id: &account_id,
451 partition,
452 creds: &creds,
453 expiration: &expiration,
454 request_id: &req.request_id,
455 },
456 );
457 Ok(AwsResponse::xml(StatusCode::OK, xml))
458 }
459
460 fn assume_role_with_saml(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
461 let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
462 AwsServiceError::aws_error(
463 StatusCode::BAD_REQUEST,
464 "MissingParameter",
465 "The request must contain the parameter RoleArn",
466 )
467 })?;
468 validate_string_length("roleArn", role_arn, 20, 2048)?;
469
470 let principal_arn = req.query_params.get("PrincipalArn").ok_or_else(|| {
472 AwsServiceError::aws_error(
473 StatusCode::BAD_REQUEST,
474 "MissingParameter",
475 "The request must contain the parameter PrincipalArn",
476 )
477 })?;
478 validate_string_length("principalArn", principal_arn, 20, 2048)?;
479 let _principal_arn = principal_arn.clone();
480
481 let saml_assertion = req.query_params.get("SAMLAssertion").ok_or_else(|| {
483 AwsServiceError::aws_error(
484 StatusCode::BAD_REQUEST,
485 "MissingParameter",
486 "The request must contain the parameter SAMLAssertion",
487 )
488 })?;
489 validate_string_length("sAMLAssertion", saml_assertion, 4, 100000)?;
490
491 validate_optional_string_length(
493 "policy",
494 req.query_params.get("Policy").map(|s| s.as_str()),
495 1,
496 2048,
497 )?;
498
499 if let Some(ds) = req.query_params.get("DurationSeconds") {
501 let v = ds.parse::<i64>().map_err(|_| {
502 AwsServiceError::aws_error(
503 StatusCode::BAD_REQUEST,
504 "ValidationError",
505 format!(
506 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
507 Member must be a valid integer",
508 ds
509 ),
510 )
511 })?;
512 validate_range_i64("durationSeconds", v, 900, 43200)?;
513 }
514
515 let expiration_at = compute_expiration_at(req, DEFAULT_ASSUME_ROLE_DURATION)?;
517 let expiration = format_expiration(expiration_at);
518
519 let role_session_name =
521 extract_saml_session_name(saml_assertion).unwrap_or_else(|| "saml-session".to_string());
522
523 let partition = partition_for_region(&req.region);
524 let creds = StsCredentials::generate();
525 let role_id = xml_responses::generate_role_id();
526
527 let mut state = self.state.write();
528 let account_id =
529 extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
530
531 let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
532 let assumed_role_arn = format!(
533 "arn:{}:sts::{}:assumed-role/{}/{}",
534 partition, account_id, role_name, &role_session_name
535 );
536 let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
537
538 state.credential_identities.insert(
539 creds.access_key_id.clone(),
540 CredentialIdentity {
541 arn: assumed_role_arn.clone(),
542 user_id: assumed_role_id_str.clone(),
543 account_id: account_id.clone(),
544 },
545 );
546 state.sts_temp_credentials.insert(
547 creds.access_key_id.clone(),
548 StsTempCredential {
549 access_key_id: creds.access_key_id.clone(),
550 secret_access_key: creds.secret_access_key.clone(),
551 session_token: creds.session_token.clone(),
552 principal_arn: assumed_role_arn,
553 user_id: assumed_role_id_str,
554 account_id: account_id.clone(),
555 expiration: expiration_at,
556 },
557 );
558
559 let xml = xml_responses::assume_role_with_saml_response(&xml_responses::AssumedRoleInfo {
560 role_arn,
561 role_session_name: &role_session_name,
562 assumed_role_id: &role_id,
563 account_id: &account_id,
564 partition,
565 creds: &creds,
566 expiration: &expiration,
567 request_id: &req.request_id,
568 });
569 Ok(AwsResponse::xml(StatusCode::OK, xml))
570 }
571
572 fn get_session_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
573 if let Some(ds) = req.query_params.get("DurationSeconds") {
575 let v = ds.parse::<i64>().map_err(|_| {
576 AwsServiceError::aws_error(
577 StatusCode::BAD_REQUEST,
578 "ValidationError",
579 format!(
580 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
581 Member must be a valid integer",
582 ds
583 ),
584 )
585 })?;
586 validate_range_i64("durationSeconds", v, 900, 129600)?;
587 }
588
589 validate_optional_string_length(
591 "serialNumber",
592 req.query_params.get("SerialNumber").map(|s| s.as_str()),
593 9,
594 256,
595 )?;
596 let _serial_number = req.query_params.get("SerialNumber").cloned();
597
598 validate_optional_string_length(
600 "tokenCode",
601 req.query_params.get("TokenCode").map(|s| s.as_str()),
602 6,
603 6,
604 )?;
605 let _token_code = req.query_params.get("TokenCode").cloned();
606
607 let expiration_at = compute_expiration_at(req, DEFAULT_SESSION_TOKEN_DURATION)?;
609 let expiration = format_expiration(expiration_at);
610
611 let partition = partition_for_region(&req.region);
617 let mut state = self.state.write();
618 let (principal_arn, user_id, account_id) =
619 if let Some(akid) = extract_access_key(req).as_deref() {
620 if let Some(lookup) = state.credential_secret(akid) {
621 (lookup.principal_arn, lookup.user_id, lookup.account_id)
622 } else {
623 (
624 format!("arn:{}:iam::{}:root", partition, state.account_id),
625 state.account_id.clone(),
626 state.account_id.clone(),
627 )
628 }
629 } else {
630 (
631 format!("arn:{}:iam::{}:root", partition, state.account_id),
632 state.account_id.clone(),
633 state.account_id.clone(),
634 )
635 };
636
637 let creds = StsCredentials::generate();
638 state.credential_identities.insert(
639 creds.access_key_id.clone(),
640 CredentialIdentity {
641 arn: principal_arn.clone(),
642 user_id: user_id.clone(),
643 account_id: account_id.clone(),
644 },
645 );
646 state.sts_temp_credentials.insert(
647 creds.access_key_id.clone(),
648 StsTempCredential {
649 access_key_id: creds.access_key_id.clone(),
650 secret_access_key: creds.secret_access_key.clone(),
651 session_token: creds.session_token.clone(),
652 principal_arn,
653 user_id,
654 account_id,
655 expiration: expiration_at,
656 },
657 );
658
659 let xml = xml_responses::get_session_token_response(&creds, &expiration, &req.request_id);
660 Ok(AwsResponse::xml(StatusCode::OK, xml))
661 }
662
663 fn get_federation_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
664 let name = req.query_params.get("Name").ok_or_else(|| {
665 AwsServiceError::aws_error(
666 StatusCode::BAD_REQUEST,
667 "MissingParameter",
668 "The request must contain the parameter Name",
669 )
670 })?;
671 validate_string_length("name", name, 2, 32)?;
672
673 if let Some(ds) = req.query_params.get("DurationSeconds") {
675 let v = ds.parse::<i64>().map_err(|_| {
676 AwsServiceError::aws_error(
677 StatusCode::BAD_REQUEST,
678 "ValidationError",
679 format!(
680 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
681 Member must be a valid integer",
682 ds
683 ),
684 )
685 })?;
686 validate_range_i64("durationSeconds", v, 900, 129600)?;
687 }
688
689 validate_optional_string_length(
691 "policy",
692 req.query_params.get("Policy").map(|s| s.as_str()),
693 1,
694 2048,
695 )?;
696 let policy = req.query_params.get("Policy").cloned();
697
698 let expiration_at = compute_expiration_at(req, DEFAULT_FEDERATION_TOKEN_DURATION)?;
700 let expiration = format_expiration(expiration_at);
701
702 let partition = partition_for_region(&req.region);
703 let creds = StsCredentials::generate();
704
705 let mut state = self.state.write();
706 let account_id = state.account_id.clone();
707 let federated_user_arn = format!(
708 "arn:{}:sts::{}:federated-user/{}",
709 partition, account_id, name
710 );
711 let federated_user_id = format!("{}:{}", account_id, name);
712
713 state.credential_identities.insert(
714 creds.access_key_id.clone(),
715 CredentialIdentity {
716 arn: federated_user_arn.clone(),
717 user_id: federated_user_id.clone(),
718 account_id: account_id.clone(),
719 },
720 );
721 state.sts_temp_credentials.insert(
722 creds.access_key_id.clone(),
723 StsTempCredential {
724 access_key_id: creds.access_key_id.clone(),
725 secret_access_key: creds.secret_access_key.clone(),
726 session_token: creds.session_token.clone(),
727 principal_arn: federated_user_arn,
728 user_id: federated_user_id,
729 account_id: account_id.clone(),
730 expiration: expiration_at,
731 },
732 );
733
734 let xml = xml_responses::get_federation_token_response(
735 &creds,
736 name,
737 &account_id,
738 partition,
739 &expiration,
740 policy.as_deref(),
741 &req.request_id,
742 );
743 Ok(AwsResponse::xml(StatusCode::OK, xml))
744 }
745
746 fn decode_authorization_message(
747 &self,
748 req: &AwsRequest,
749 ) -> Result<AwsResponse, AwsServiceError> {
750 let encoded_message = req.query_params.get("EncodedMessage").ok_or_else(|| {
751 AwsServiceError::aws_error(
752 StatusCode::BAD_REQUEST,
753 "MissingParameter",
754 "The request must contain the parameter EncodedMessage",
755 )
756 })?;
757 validate_string_length("encodedMessage", encoded_message, 1, 10240)?;
758
759 let decoded_message =
760 r#"{"allowed":true,"explicitDeny":false,"matchedStatements":{"items":[]}}"#;
761 let xml =
762 xml_responses::decode_authorization_message_response(decoded_message, &req.request_id);
763 Ok(AwsResponse::xml(StatusCode::OK, xml))
764 }
765
766 fn get_access_key_info(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
767 let access_key_id = req.query_params.get("AccessKeyId").ok_or_else(|| {
768 AwsServiceError::aws_error(
769 StatusCode::BAD_REQUEST,
770 "MissingParameter",
771 "The request must contain the parameter AccessKeyId",
772 )
773 })?;
774 validate_string_length("accessKeyId", access_key_id, 16, 128)?;
775
776 let state = self.state.read();
778 let account_id = state
779 .access_keys
780 .values()
781 .flatten()
782 .find(|k| k.access_key_id == *access_key_id)
783 .map(|_| state.account_id.clone())
784 .or_else(|| {
785 state
786 .credential_identities
787 .get(access_key_id.as_str())
788 .map(|ci| ci.account_id.clone())
789 })
790 .unwrap_or_else(|| state.account_id.clone());
791
792 let xml = xml_responses::get_access_key_info_response(&account_id, &req.request_id);
793 Ok(AwsResponse::xml(StatusCode::OK, xml))
794 }
795}
796
797fn extract_account_from_arn(arn: &str) -> Option<String> {
799 let parts: Vec<&str> = arn.split(':').collect();
800 if parts.len() >= 5 && !parts[4].is_empty() {
801 Some(parts[4].to_string())
802 } else {
803 None
804 }
805}
806
807fn extract_saml_session_name(saml_b64: &str) -> Option<String> {
809 use base64::Engine;
810 let decoded = base64::engine::general_purpose::STANDARD
811 .decode(saml_b64)
812 .ok()?;
813 let xml_str = String::from_utf8(decoded).ok()?;
814
815 let role_session_attr = "https://aws.amazon.com/SAML/Attributes/RoleSessionName";
817 let pos = xml_str.find(role_session_attr)?;
818
819 let after = &xml_str[pos..];
821 let av_start = after.find("AttributeValue")?;
822 let after_av = &after[av_start..];
823 let gt_pos = after_av.find('>')?;
825 let value_start = &after_av[gt_pos + 1..];
826 let lt_pos = value_start.find('<')?;
828 let value = value_start[..lt_pos].trim();
829
830 if value.is_empty() {
831 None
832 } else {
833 Some(value.to_string())
834 }
835}
836
837#[cfg(test)]
838mod tests {
839 use super::*;
840
841 #[test]
842 fn test_partition_for_region() {
843 assert_eq!(partition_for_region("us-east-1"), "aws");
844 assert_eq!(partition_for_region("eu-west-1"), "aws");
845 assert_eq!(partition_for_region("cn-north-1"), "aws-cn");
846 assert_eq!(partition_for_region("cn-northwest-1"), "aws-cn");
847 assert_eq!(partition_for_region("us-isob-east-1"), "aws-iso-b");
848 assert_eq!(partition_for_region("us-iso-east-1"), "aws-iso");
849 }
850
851 #[test]
852 fn test_extract_account_from_arn() {
853 assert_eq!(
854 extract_account_from_arn("arn:aws:iam::123456789012:role/test"),
855 Some("123456789012".to_string())
856 );
857 assert_eq!(
858 extract_account_from_arn("arn:aws:iam::111111111111:role/test"),
859 Some("111111111111".to_string())
860 );
861 assert_eq!(extract_account_from_arn("invalid"), None);
862 }
863
864 #[test]
865 fn test_extract_saml_session_name() {
866 use base64::Engine;
867 let xml = r#"<?xml version="1.0"?><samlp:Response><Assertion><AttributeStatement><Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><AttributeValue>testuser</AttributeValue></Attribute></AttributeStatement></Assertion></samlp:Response>"#;
868 let encoded = base64::engine::general_purpose::STANDARD.encode(xml.as_bytes());
869 assert_eq!(
870 extract_saml_session_name(&encoded),
871 Some("testuser".to_string())
872 );
873 }
874
875 #[test]
876 fn test_extract_saml_session_name_with_namespace() {
877 use base64::Engine;
878 let xml = r#"<?xml version="1.0"?><samlp:Response><saml:Assertion><saml:AttributeStatement><saml:Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><saml:AttributeValue>testuser</saml:AttributeValue></saml:Attribute></saml:AttributeStatement></saml:Assertion></samlp:Response>"#;
879 let encoded = base64::engine::general_purpose::STANDARD.encode(xml.as_bytes());
880 assert_eq!(
881 extract_saml_session_name(&encoded),
882 Some("testuser".to_string())
883 );
884 }
885
886 #[test]
887 fn test_session_token_format() {
888 let token = xml_responses::generate_session_token();
889 assert_eq!(token.len(), 356);
890 assert!(token.starts_with("FQoGZXIvYXdzE"));
891 }
892
893 #[test]
894 fn test_access_key_id_format() {
895 let key = xml_responses::generate_access_key_id();
896 assert_eq!(key.len(), 20);
897 assert!(key.starts_with("FSIA"));
898 }
899
900 #[test]
901 fn test_secret_access_key_format() {
902 let key = xml_responses::generate_secret_access_key();
903 assert_eq!(key.len(), 40);
904 }
905
906 #[test]
907 fn test_role_id_format() {
908 let id = xml_responses::generate_role_id();
909 assert_eq!(id.len(), 21);
910 assert!(id.starts_with("AROA"));
911 }
912
913 #[test]
914 fn test_decode_authorization_message() {
915 use crate::state::IamState;
916 use parking_lot::RwLock;
917 use std::collections::HashMap;
918 use std::sync::Arc;
919
920 let state: SharedIamState = Arc::new(RwLock::new(IamState::new("123456789012")));
921 let service = StsService::new(state);
922
923 let mut params = HashMap::new();
924 params.insert(
925 "EncodedMessage".to_string(),
926 "some-encoded-message".to_string(),
927 );
928
929 let req = make_test_request(params);
930 let resp = service.decode_authorization_message(&req).unwrap();
931 let body = std::str::from_utf8(resp.body.expect_bytes()).unwrap();
932 assert!(body.contains("DecodedMessage"));
933 assert!(body.contains("allowed"));
934 assert!(body.contains("matchedStatements"));
935 }
936
937 #[test]
938 fn test_decode_authorization_message_missing_param() {
939 use crate::state::IamState;
940 use parking_lot::RwLock;
941 use std::collections::HashMap;
942 use std::sync::Arc;
943
944 let state: SharedIamState = Arc::new(RwLock::new(IamState::new("123456789012")));
945 let service = StsService::new(state);
946
947 let req = make_test_request(HashMap::new());
948 let result = service.decode_authorization_message(&req);
949 assert!(result.is_err());
950 let err = result.err().unwrap();
951 let msg = format!("{:?}", err);
952 assert!(msg.contains("EncodedMessage"));
953 }
954
955 fn make_test_request(params: std::collections::HashMap<String, String>) -> AwsRequest {
956 AwsRequest {
957 service: "sts".into(),
958 action: "Test".into(),
959 region: "us-east-1".into(),
960 account_id: "123456789012".into(),
961 request_id: "test".into(),
962 headers: http::HeaderMap::new(),
963 query_params: params,
964 body: Default::default(),
965 path_segments: vec![],
966 raw_path: "/".into(),
967 raw_query: String::new(),
968 method: http::Method::POST,
969 is_query_protocol: true,
970 access_key_id: None,
971 principal: None,
972 }
973 }
974
975 fn parse_expiration(s: &str) -> chrono::DateTime<Utc> {
976 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ")
977 .expect("valid timestamp")
978 .and_utc()
979 }
980
981 #[test]
982 fn test_compute_expiration_with_duration() {
983 use std::collections::HashMap;
984
985 let mut params = HashMap::new();
986 params.insert("DurationSeconds".to_string(), "1800".to_string());
987 let req = make_test_request(params);
988
989 let now = Utc::now();
990 let exp_str = compute_expiration(&req, 3600).unwrap();
991 let exp_utc = parse_expiration(&exp_str);
992
993 let diff = (exp_utc - now).num_seconds();
995 assert!(
996 (1798..=1802).contains(&diff),
997 "expected ~1800s duration, got {diff}s"
998 );
999 }
1000
1001 #[test]
1002 fn test_compute_expiration_default() {
1003 use std::collections::HashMap;
1004
1005 let req = make_test_request(HashMap::new());
1006
1007 let now = Utc::now();
1008 let exp_str = compute_expiration(&req, 43200).unwrap();
1009 let exp_utc = parse_expiration(&exp_str);
1010
1011 let diff = (exp_utc - now).num_seconds();
1013 assert!(
1014 (43198..=43202).contains(&diff),
1015 "expected ~43200s duration, got {diff}s"
1016 );
1017 }
1018
1019 #[test]
1020 fn test_compute_expiration_uses_provided_not_default() {
1021 use std::collections::HashMap;
1022
1023 let mut params = HashMap::new();
1024 params.insert("DurationSeconds".to_string(), "900".to_string());
1025 let req = make_test_request(params);
1026
1027 let before = Utc::now();
1028 let exp_str = compute_expiration(&req, 43200).unwrap();
1029 let exp_utc = parse_expiration(&exp_str);
1030
1031 let expected = before + chrono::Duration::seconds(900);
1033 let diff = (exp_utc - expected).num_seconds().abs();
1034 assert!(
1035 diff <= 2,
1036 "expected ~900s duration, got diff={diff}s from expected"
1037 );
1038 }
1039}