1use async_trait::async_trait;
2use chrono::Utc;
3use http::StatusCode;
4
5use fakecloud_core::service::{AwsRequest, AwsResponse, AwsService, AwsServiceError};
6use fakecloud_core::validation::*;
7
8use crate::state::{CredentialIdentity, SharedIamState};
9use crate::xml_responses::{self, StsCredentials};
10
/// Lifetime in seconds used for `AssumeRole*` credentials when the request omits
/// `DurationSeconds`: one hour.
const DEFAULT_ASSUME_ROLE_DURATION: i64 = 60 * 60;

/// Lifetime in seconds used for `GetSessionToken` credentials when the request omits
/// `DurationSeconds`: twelve hours.
const DEFAULT_SESSION_TOKEN_DURATION: i64 = 12 * 60 * 60;

/// Lifetime in seconds used for `GetFederationToken` credentials when the request omits
/// `DurationSeconds`: twelve hours.
const DEFAULT_FEDERATION_TOKEN_DURATION: i64 = 12 * 60 * 60;
19
20fn compute_expiration(req: &AwsRequest, default_duration: i64) -> Result<String, AwsServiceError> {
22 let duration = if let Some(ds) = req.query_params.get("DurationSeconds") {
23 ds.parse::<i64>().map_err(|_| {
24 AwsServiceError::aws_error(
25 StatusCode::BAD_REQUEST,
26 "ValidationError",
27 format!(
28 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
29 Member must be a valid integer",
30 ds
31 ),
32 )
33 })?
34 } else {
35 default_duration
36 };
37 let expiration = Utc::now() + chrono::Duration::seconds(duration);
38 Ok(expiration.format("%Y-%m-%dT%H:%M:%SZ").to_string())
39}
40
41pub struct StsService {
42 state: SharedIamState,
43}
44
45impl StsService {
46 pub fn new(state: SharedIamState) -> Self {
47 Self { state }
48 }
49}
50
51#[async_trait]
52impl AwsService for StsService {
53 fn service_name(&self) -> &str {
54 "sts"
55 }
56
57 async fn handle(&self, req: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
58 match req.action.as_str() {
59 "GetCallerIdentity" => self.get_caller_identity(&req),
60 "AssumeRole" => self.assume_role(&req),
61 "AssumeRoleWithWebIdentity" => self.assume_role_with_web_identity(&req),
62 "AssumeRoleWithSAML" => self.assume_role_with_saml(&req),
63 "GetSessionToken" => self.get_session_token(&req),
64 "GetFederationToken" => self.get_federation_token(&req),
65 "GetAccessKeyInfo" => self.get_access_key_info(&req),
66 "DecodeAuthorizationMessage" => self.decode_authorization_message(&req),
67 _ => Err(AwsServiceError::action_not_implemented("sts", &req.action)),
68 }
69 }
70
71 fn supported_actions(&self) -> &[&str] {
72 &[
73 "GetCallerIdentity",
74 "AssumeRole",
75 "AssumeRoleWithWebIdentity",
76 "AssumeRoleWithSAML",
77 "GetSessionToken",
78 "GetFederationToken",
79 "GetAccessKeyInfo",
80 "DecodeAuthorizationMessage",
81 ]
82 }
83}
84
/// Maps a region name to its AWS partition (the second ARN component).
///
/// Matching is by region-name prefix; anything unrecognised falls back to the commercial
/// `aws` partition. GovCloud regions (`us-gov-*`) belong to `aws-us-gov`.
fn partition_for_region(region: &str) -> &str {
    // Prefixes are mutually exclusive (e.g. "us-iso-" does not match "us-isob-"), so the
    // order of this table does not matter.
    const PARTITIONS: &[(&str, &str)] = &[
        ("cn-", "aws-cn"),
        ("us-gov-", "aws-us-gov"),
        ("us-iso-", "aws-iso"),
        ("us-isob-", "aws-iso-b"),
        ("us-isof-", "aws-iso-f"),
        ("eu-isoe-", "aws-iso-e"),
    ];

    PARTITIONS
        .iter()
        .find(|(prefix, _)| region.starts_with(*prefix))
        .map_or("aws", |&(_, partition)| partition)
}
101
102fn extract_access_key(req: &AwsRequest) -> Option<String> {
104 let auth = req.headers.get("authorization")?.to_str().ok()?;
105 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
106 Some(info.access_key)
107}
108
109impl StsService {
110 fn get_caller_identity(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
111 let state = self.state.read();
112 let partition = partition_for_region(&req.region);
113
114 if let Some(access_key) = extract_access_key(req) {
116 if let Some(identity) = state.credential_identities.get(&access_key) {
118 let xml = xml_responses::get_caller_identity_response(
119 &identity.account_id,
120 &identity.arn,
121 &identity.user_id,
122 &req.request_id,
123 );
124 return Ok(AwsResponse::xml(StatusCode::OK, xml));
125 }
126
127 for keys in state.access_keys.values() {
129 for key in keys {
130 if key.access_key_id == access_key {
131 if let Some(user) = state.users.get(&key.user_name) {
132 let xml = xml_responses::get_caller_identity_response(
133 &state.account_id,
134 &user.arn,
135 &user.user_id,
136 &req.request_id,
137 );
138 return Ok(AwsResponse::xml(StatusCode::OK, xml));
139 }
140 }
141 }
142 }
143 }
144
145 let arn = format!("arn:{}:iam::{}:root", partition, state.account_id);
147 let user_id = "FKIAIOSFODNN7EXAMPLE";
148 let xml = xml_responses::get_caller_identity_response(
149 &state.account_id,
150 &arn,
151 user_id,
152 &req.request_id,
153 );
154 Ok(AwsResponse::xml(StatusCode::OK, xml))
155 }
156
157 fn assume_role(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
158 let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
159 AwsServiceError::aws_error(
160 StatusCode::BAD_REQUEST,
161 "MissingParameter",
162 "The request must contain the parameter RoleArn",
163 )
164 })?;
165 validate_string_length("roleArn", role_arn, 20, 2048)?;
166
167 let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
168 AwsServiceError::aws_error(
169 StatusCode::BAD_REQUEST,
170 "MissingParameter",
171 "The request must contain the parameter RoleSessionName",
172 )
173 })?;
174 validate_string_length("roleSessionName", role_session_name, 2, 64)?;
175
176 if let Some(ds) = req.query_params.get("DurationSeconds") {
178 let v = ds.parse::<i64>().map_err(|_| {
179 AwsServiceError::aws_error(
180 StatusCode::BAD_REQUEST,
181 "ValidationError",
182 format!(
183 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
184 Member must be a valid integer",
185 ds
186 ),
187 )
188 })?;
189 validate_range_i64("durationSeconds", v, 900, 43200)?;
190 }
191
192 validate_optional_string_length(
194 "externalId",
195 req.query_params.get("ExternalId").map(|s| s.as_str()),
196 2,
197 1224,
198 )?;
199
200 validate_optional_string_length(
202 "policy",
203 req.query_params.get("Policy").map(|s| s.as_str()),
204 1,
205 2048,
206 )?;
207
208 validate_optional_string_length(
210 "sourceIdentity",
211 req.query_params.get("SourceIdentity").map(|s| s.as_str()),
212 2,
213 64,
214 )?;
215
216 validate_optional_string_length(
218 "serialNumber",
219 req.query_params.get("SerialNumber").map(|s| s.as_str()),
220 9,
221 256,
222 )?;
223 let serial_number = req.query_params.get("SerialNumber").cloned();
224
225 validate_optional_string_length(
227 "tokenCode",
228 req.query_params.get("TokenCode").map(|s| s.as_str()),
229 6,
230 6,
231 )?;
232 let token_code = req.query_params.get("TokenCode").cloned();
233
234 let expiration = compute_expiration(req, DEFAULT_ASSUME_ROLE_DURATION)?;
236
237 let _mfa_serial = serial_number;
239 let _mfa_token = token_code;
240
241 let partition = partition_for_region(&req.region);
242 let creds = StsCredentials::generate();
243
244 let mut state = self.state.write();
245
246 let account_id =
248 extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
249
250 let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
252 let role_id = state
253 .roles
254 .get(role_name)
255 .map(|r| r.role_id.clone())
256 .unwrap_or_else(xml_responses::generate_role_id);
257
258 let assumed_role_arn = format!(
259 "arn:{}:sts::{}:assumed-role/{}/{}",
260 partition, account_id, role_name, role_session_name
261 );
262 let assumed_role_id = format!("{}:{}", role_id, role_session_name);
263
264 state.credential_identities.insert(
266 creds.access_key_id.clone(),
267 CredentialIdentity {
268 arn: assumed_role_arn,
269 user_id: assumed_role_id,
270 account_id: account_id.clone(),
271 },
272 );
273
274 let xml = xml_responses::assume_role_response(
275 role_arn,
276 role_session_name,
277 &role_id,
278 &account_id,
279 partition,
280 &creds,
281 &expiration,
282 &req.request_id,
283 );
284 Ok(AwsResponse::xml(StatusCode::OK, xml))
285 }
286
287 fn assume_role_with_web_identity(
288 &self,
289 req: &AwsRequest,
290 ) -> Result<AwsResponse, AwsServiceError> {
291 let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
292 AwsServiceError::aws_error(
293 StatusCode::BAD_REQUEST,
294 "MissingParameter",
295 "The request must contain the parameter RoleArn",
296 )
297 })?;
298 validate_string_length("roleArn", role_arn, 20, 2048)?;
299
300 let role_session_name = req.query_params.get("RoleSessionName").ok_or_else(|| {
301 AwsServiceError::aws_error(
302 StatusCode::BAD_REQUEST,
303 "MissingParameter",
304 "The request must contain the parameter RoleSessionName",
305 )
306 })?;
307 validate_string_length("roleSessionName", role_session_name, 2, 64)?;
308
309 let web_identity_token = req.query_params.get("WebIdentityToken").ok_or_else(|| {
311 AwsServiceError::aws_error(
312 StatusCode::BAD_REQUEST,
313 "MissingParameter",
314 "The request must contain the parameter WebIdentityToken",
315 )
316 })?;
317 validate_string_length("webIdentityToken", web_identity_token, 4, 20000)?;
318 let _web_identity_token = web_identity_token.clone();
319
320 validate_optional_string_length(
322 "policy",
323 req.query_params.get("Policy").map(|s| s.as_str()),
324 1,
325 2048,
326 )?;
327
328 validate_optional_string_length(
330 "providerId",
331 req.query_params.get("ProviderId").map(|s| s.as_str()),
332 4,
333 2048,
334 )?;
335
336 if let Some(ds) = req.query_params.get("DurationSeconds") {
338 let v = ds.parse::<i64>().map_err(|_| {
339 AwsServiceError::aws_error(
340 StatusCode::BAD_REQUEST,
341 "ValidationError",
342 format!(
343 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
344 Member must be a valid integer",
345 ds
346 ),
347 )
348 })?;
349 validate_range_i64("durationSeconds", v, 900, 43200)?;
350 }
351
352 let expiration = compute_expiration(req, DEFAULT_ASSUME_ROLE_DURATION)?;
354
355 let partition = partition_for_region(&req.region);
356 let creds = StsCredentials::generate();
357 let role_id = xml_responses::generate_role_id();
358
359 let mut state = self.state.write();
360 let account_id =
361 extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
362
363 let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
364 let assumed_role_arn = format!(
365 "arn:{}:sts::{}:assumed-role/{}/{}",
366 partition, account_id, role_name, role_session_name
367 );
368 let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
369
370 state.credential_identities.insert(
371 creds.access_key_id.clone(),
372 CredentialIdentity {
373 arn: assumed_role_arn,
374 user_id: assumed_role_id_str,
375 account_id: account_id.clone(),
376 },
377 );
378
379 let xml = xml_responses::assume_role_with_web_identity_response(
380 role_arn,
381 role_session_name,
382 &account_id,
383 partition,
384 &creds,
385 &role_id,
386 &expiration,
387 &req.request_id,
388 );
389 Ok(AwsResponse::xml(StatusCode::OK, xml))
390 }
391
392 fn assume_role_with_saml(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
393 let role_arn = req.query_params.get("RoleArn").ok_or_else(|| {
394 AwsServiceError::aws_error(
395 StatusCode::BAD_REQUEST,
396 "MissingParameter",
397 "The request must contain the parameter RoleArn",
398 )
399 })?;
400 validate_string_length("roleArn", role_arn, 20, 2048)?;
401
402 let principal_arn = req.query_params.get("PrincipalArn").ok_or_else(|| {
404 AwsServiceError::aws_error(
405 StatusCode::BAD_REQUEST,
406 "MissingParameter",
407 "The request must contain the parameter PrincipalArn",
408 )
409 })?;
410 validate_string_length("principalArn", principal_arn, 20, 2048)?;
411 let _principal_arn = principal_arn.clone();
412
413 let saml_assertion = req.query_params.get("SAMLAssertion").ok_or_else(|| {
415 AwsServiceError::aws_error(
416 StatusCode::BAD_REQUEST,
417 "MissingParameter",
418 "The request must contain the parameter SAMLAssertion",
419 )
420 })?;
421 validate_string_length("sAMLAssertion", saml_assertion, 4, 100000)?;
422
423 validate_optional_string_length(
425 "policy",
426 req.query_params.get("Policy").map(|s| s.as_str()),
427 1,
428 2048,
429 )?;
430
431 if let Some(ds) = req.query_params.get("DurationSeconds") {
433 let v = ds.parse::<i64>().map_err(|_| {
434 AwsServiceError::aws_error(
435 StatusCode::BAD_REQUEST,
436 "ValidationError",
437 format!(
438 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
439 Member must be a valid integer",
440 ds
441 ),
442 )
443 })?;
444 validate_range_i64("durationSeconds", v, 900, 43200)?;
445 }
446
447 let expiration = compute_expiration(req, DEFAULT_ASSUME_ROLE_DURATION)?;
449
450 let role_session_name =
452 extract_saml_session_name(saml_assertion).unwrap_or_else(|| "saml-session".to_string());
453
454 let partition = partition_for_region(&req.region);
455 let creds = StsCredentials::generate();
456 let role_id = xml_responses::generate_role_id();
457
458 let mut state = self.state.write();
459 let account_id =
460 extract_account_from_arn(role_arn).unwrap_or_else(|| state.account_id.clone());
461
462 let role_name = role_arn.rsplit('/').next().unwrap_or("unknown");
463 let assumed_role_arn = format!(
464 "arn:{}:sts::{}:assumed-role/{}/{}",
465 partition, account_id, role_name, &role_session_name
466 );
467 let assumed_role_id_str = format!("{}:{}", role_id, role_session_name);
468
469 state.credential_identities.insert(
470 creds.access_key_id.clone(),
471 CredentialIdentity {
472 arn: assumed_role_arn,
473 user_id: assumed_role_id_str,
474 account_id: account_id.clone(),
475 },
476 );
477
478 let xml = xml_responses::assume_role_with_saml_response(
479 role_arn,
480 &role_session_name,
481 &account_id,
482 partition,
483 &creds,
484 &role_id,
485 &expiration,
486 &req.request_id,
487 );
488 Ok(AwsResponse::xml(StatusCode::OK, xml))
489 }
490
491 fn get_session_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
492 if let Some(ds) = req.query_params.get("DurationSeconds") {
494 let v = ds.parse::<i64>().map_err(|_| {
495 AwsServiceError::aws_error(
496 StatusCode::BAD_REQUEST,
497 "ValidationError",
498 format!(
499 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
500 Member must be a valid integer",
501 ds
502 ),
503 )
504 })?;
505 validate_range_i64("durationSeconds", v, 900, 129600)?;
506 }
507
508 validate_optional_string_length(
510 "serialNumber",
511 req.query_params.get("SerialNumber").map(|s| s.as_str()),
512 9,
513 256,
514 )?;
515 let _serial_number = req.query_params.get("SerialNumber").cloned();
516
517 validate_optional_string_length(
519 "tokenCode",
520 req.query_params.get("TokenCode").map(|s| s.as_str()),
521 6,
522 6,
523 )?;
524 let _token_code = req.query_params.get("TokenCode").cloned();
525
526 let expiration = compute_expiration(req, DEFAULT_SESSION_TOKEN_DURATION)?;
528
529 let xml = xml_responses::get_session_token_response(&expiration, &req.request_id);
530 Ok(AwsResponse::xml(StatusCode::OK, xml))
531 }
532
533 fn get_federation_token(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
534 let name = req.query_params.get("Name").ok_or_else(|| {
535 AwsServiceError::aws_error(
536 StatusCode::BAD_REQUEST,
537 "MissingParameter",
538 "The request must contain the parameter Name",
539 )
540 })?;
541 validate_string_length("name", name, 2, 32)?;
542
543 if let Some(ds) = req.query_params.get("DurationSeconds") {
545 let v = ds.parse::<i64>().map_err(|_| {
546 AwsServiceError::aws_error(
547 StatusCode::BAD_REQUEST,
548 "ValidationError",
549 format!(
550 "Value '{}' at 'durationSeconds' failed to satisfy constraint: \
551 Member must be a valid integer",
552 ds
553 ),
554 )
555 })?;
556 validate_range_i64("durationSeconds", v, 900, 129600)?;
557 }
558
559 validate_optional_string_length(
561 "policy",
562 req.query_params.get("Policy").map(|s| s.as_str()),
563 1,
564 2048,
565 )?;
566 let policy = req.query_params.get("Policy").cloned();
567
568 let expiration = compute_expiration(req, DEFAULT_FEDERATION_TOKEN_DURATION)?;
570
571 let partition = partition_for_region(&req.region);
572 let state = self.state.read();
573 let xml = xml_responses::get_federation_token_response(
574 name,
575 &state.account_id,
576 partition,
577 &expiration,
578 policy.as_deref(),
579 &req.request_id,
580 );
581 Ok(AwsResponse::xml(StatusCode::OK, xml))
582 }
583
584 fn decode_authorization_message(
585 &self,
586 req: &AwsRequest,
587 ) -> Result<AwsResponse, AwsServiceError> {
588 let _encoded_message = req.query_params.get("EncodedMessage").ok_or_else(|| {
589 AwsServiceError::aws_error(
590 StatusCode::BAD_REQUEST,
591 "MissingParameter",
592 "The request must contain the parameter EncodedMessage",
593 )
594 })?;
595
596 let decoded_message =
597 r#"{"allowed":true,"explicitDeny":false,"matchedStatements":{"items":[]}}"#;
598 let xml =
599 xml_responses::decode_authorization_message_response(decoded_message, &req.request_id);
600 Ok(AwsResponse::xml(StatusCode::OK, xml))
601 }
602
603 fn get_access_key_info(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
604 let access_key_id = req.query_params.get("AccessKeyId").ok_or_else(|| {
605 AwsServiceError::aws_error(
606 StatusCode::BAD_REQUEST,
607 "MissingParameter",
608 "The request must contain the parameter AccessKeyId",
609 )
610 })?;
611 validate_string_length("accessKeyId", access_key_id, 16, 128)?;
612
613 let state = self.state.read();
615 let account_id = state
616 .access_keys
617 .values()
618 .flatten()
619 .find(|k| k.access_key_id == *access_key_id)
620 .map(|_| state.account_id.clone())
621 .or_else(|| {
622 state
623 .credential_identities
624 .get(access_key_id.as_str())
625 .map(|ci| ci.account_id.clone())
626 })
627 .unwrap_or_else(|| state.account_id.clone());
628
629 let xml = xml_responses::get_access_key_info_response(&account_id, &req.request_id);
630 Ok(AwsResponse::xml(StatusCode::OK, xml))
631 }
632}
633
/// Returns the account-id component (the fifth `:`-separated field) of an ARN.
///
/// Yields `None` when the string has fewer than five fields or the account field is
/// empty (as in `arn:aws:s3:::bucket`).
fn extract_account_from_arn(arn: &str) -> Option<String> {
    arn.split(':')
        .nth(4)
        .filter(|account| !account.is_empty())
        .map(str::to_string)
}
643
644fn extract_saml_session_name(saml_b64: &str) -> Option<String> {
646 use base64::Engine;
647 let decoded = base64::engine::general_purpose::STANDARD
648 .decode(saml_b64)
649 .ok()?;
650 let xml_str = String::from_utf8(decoded).ok()?;
651
652 let role_session_attr = "https://aws.amazon.com/SAML/Attributes/RoleSessionName";
654 let pos = xml_str.find(role_session_attr)?;
655
656 let after = &xml_str[pos..];
658 let av_start = after.find("AttributeValue")?;
659 let after_av = &after[av_start..];
660 let gt_pos = after_av.find('>')?;
662 let value_start = &after_av[gt_pos + 1..];
663 let lt_pos = value_start.find('<')?;
665 let value = value_start[..lt_pos].trim();
666
667 if value.is_empty() {
668 None
669 } else {
670 Some(value.to_string())
671 }
672}
673
#[cfg(test)]
mod tests {
    //! Unit tests for the STS helpers and a few handler error paths.

    use super::*;
    use crate::state::IamState;
    use parking_lot::RwLock;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds a service over fresh IAM state for account `123456789012`.
    fn test_service() -> StsService {
        let state: SharedIamState = Arc::new(RwLock::new(IamState::new("123456789012")));
        StsService::new(state)
    }

    /// Builds a minimal unsigned query-protocol STS request carrying `params`.
    fn make_test_request(params: HashMap<String, String>) -> AwsRequest {
        AwsRequest {
            service: "sts".into(),
            action: "Test".into(),
            region: "us-east-1".into(),
            account_id: "123456789012".into(),
            request_id: "test".into(),
            headers: http::HeaderMap::new(),
            query_params: params,
            body: Default::default(),
            path_segments: vec![],
            raw_path: "/".into(),
            method: http::Method::POST,
            is_query_protocol: true,
            access_key_id: None,
        }
    }

    /// Parses timestamps in the `%Y-%m-%dT%H:%M:%SZ` format `compute_expiration` emits.
    fn parse_expiration(s: &str) -> chrono::DateTime<Utc> {
        chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%SZ")
            .expect("valid timestamp")
            .and_utc()
    }

    #[test]
    fn test_partition_for_region() {
        assert_eq!(partition_for_region("us-east-1"), "aws");
        assert_eq!(partition_for_region("eu-west-1"), "aws");
        assert_eq!(partition_for_region("cn-north-1"), "aws-cn");
        assert_eq!(partition_for_region("cn-northwest-1"), "aws-cn");
        assert_eq!(partition_for_region("us-isob-east-1"), "aws-iso-b");
        assert_eq!(partition_for_region("us-iso-east-1"), "aws-iso");
        assert_eq!(partition_for_region("us-isof-south-1"), "aws-iso-f");
        assert_eq!(partition_for_region("eu-isoe-west-1"), "aws-iso-e");
    }

    #[test]
    fn test_extract_account_from_arn() {
        assert_eq!(
            extract_account_from_arn("arn:aws:iam::123456789012:role/test"),
            Some("123456789012".to_string())
        );
        assert_eq!(
            extract_account_from_arn("arn:aws:iam::111111111111:role/test"),
            Some("111111111111".to_string())
        );
        assert_eq!(extract_account_from_arn("invalid"), None);
        // Empty account segment (global resources) and truncated ARNs.
        assert_eq!(extract_account_from_arn("arn:aws:s3:::my-bucket"), None);
        assert_eq!(extract_account_from_arn("arn:aws:iam"), None);
    }

    #[test]
    fn test_extract_saml_session_name() {
        use base64::Engine;
        let xml = r#"<?xml version="1.0"?><samlp:Response><Assertion><AttributeStatement><Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><AttributeValue>testuser</AttributeValue></Attribute></AttributeStatement></Assertion></samlp:Response>"#;
        let encoded = base64::engine::general_purpose::STANDARD.encode(xml.as_bytes());
        assert_eq!(
            extract_saml_session_name(&encoded),
            Some("testuser".to_string())
        );
    }

    #[test]
    fn test_extract_saml_session_name_with_namespace() {
        use base64::Engine;
        let xml = r#"<?xml version="1.0"?><samlp:Response><saml:Assertion><saml:AttributeStatement><saml:Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName"><saml:AttributeValue>testuser</saml:AttributeValue></saml:Attribute></saml:AttributeStatement></saml:Assertion></samlp:Response>"#;
        let encoded = base64::engine::general_purpose::STANDARD.encode(xml.as_bytes());
        assert_eq!(
            extract_saml_session_name(&encoded),
            Some("testuser".to_string())
        );
    }

    #[test]
    fn test_session_token_format() {
        let token = xml_responses::generate_session_token();
        assert_eq!(token.len(), 356);
        assert!(token.starts_with("FQoGZXIvYXdzE"));
    }

    #[test]
    fn test_access_key_id_format() {
        let key = xml_responses::generate_access_key_id();
        assert_eq!(key.len(), 20);
        assert!(key.starts_with("FSIA"));
    }

    #[test]
    fn test_secret_access_key_format() {
        let key = xml_responses::generate_secret_access_key();
        assert_eq!(key.len(), 40);
    }

    #[test]
    fn test_role_id_format() {
        let id = xml_responses::generate_role_id();
        assert_eq!(id.len(), 21);
        assert!(id.starts_with("AROA"));
    }

    #[test]
    fn test_decode_authorization_message() {
        let service = test_service();

        let mut params = HashMap::new();
        params.insert(
            "EncodedMessage".to_string(),
            "some-encoded-message".to_string(),
        );

        let req = make_test_request(params);
        let resp = service.decode_authorization_message(&req).unwrap();
        let body = std::str::from_utf8(&resp.body).unwrap();
        assert!(body.contains("DecodedMessage"));
        assert!(body.contains("allowed"));
        assert!(body.contains("matchedStatements"));
    }

    #[test]
    fn test_decode_authorization_message_missing_param() {
        let service = test_service();

        let req = make_test_request(HashMap::new());
        let result = service.decode_authorization_message(&req);
        assert!(result.is_err());
        let err = result.err().unwrap();
        let msg = format!("{:?}", err);
        assert!(msg.contains("EncodedMessage"));
    }

    #[test]
    fn test_assume_role_missing_role_arn() {
        let service = test_service();

        let req = make_test_request(HashMap::new());
        let msg = format!("{:?}", service.assume_role(&req).err().unwrap());
        assert!(msg.contains("RoleArn"));
    }

    #[test]
    fn test_get_session_token_rejects_non_integer_duration() {
        let service = test_service();

        let mut params = HashMap::new();
        params.insert("DurationSeconds".to_string(), "abc".to_string());
        let req = make_test_request(params);
        let msg = format!("{:?}", service.get_session_token(&req).err().unwrap());
        assert!(msg.contains("durationSeconds"));
    }

    #[test]
    fn test_compute_expiration_with_duration() {
        let mut params = HashMap::new();
        params.insert("DurationSeconds".to_string(), "1800".to_string());
        let req = make_test_request(params);

        let now = Utc::now();
        let exp_str = compute_expiration(&req, 3600).unwrap();
        let exp_utc = parse_expiration(&exp_str);

        // Allow a little slack for clock movement and second truncation.
        let diff = (exp_utc - now).num_seconds();
        assert!(
            (1798..=1802).contains(&diff),
            "expected ~1800s duration, got {diff}s"
        );
    }

    #[test]
    fn test_compute_expiration_default() {
        let req = make_test_request(HashMap::new());

        let now = Utc::now();
        let exp_str = compute_expiration(&req, 43200).unwrap();
        let exp_utc = parse_expiration(&exp_str);

        let diff = (exp_utc - now).num_seconds();
        assert!(
            (43198..=43202).contains(&diff),
            "expected ~43200s duration, got {diff}s"
        );
    }

    #[test]
    fn test_compute_expiration_uses_provided_not_default() {
        let mut params = HashMap::new();
        params.insert("DurationSeconds".to_string(), "900".to_string());
        let req = make_test_request(params);

        let before = Utc::now();
        let exp_str = compute_expiration(&req, 43200).unwrap();
        let exp_utc = parse_expiration(&exp_str);

        let expected = before + chrono::Duration::seconds(900);
        let diff = (exp_utc - expected).num_seconds().abs();
        assert!(
            diff <= 2,
            "expected ~900s duration, got diff={diff}s from expected"
        );
    }
}