1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Wire-protocol families the request detector can classify a request into.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AwsProtocol {
    /// Form-style request identified by an `Action` parameter (query string or body).
    Query,
    /// JSON request routed by the `X-Amz-Target` header.
    Json,
    /// REST request for a service listed in `REST_XML_SERVICES`.
    Rest,
    /// REST request for a service listed in `REST_JSON_SERVICES`.
    RestJson,
}

/// Services whose non-`Action` requests are classified as [`AwsProtocol::Rest`].
const REST_XML_SERVICES: &[&str] = &[
    "s3",
    "cloudfront",
    "route53",
];

/// Services whose non-`Action` requests are classified as [`AwsProtocol::RestJson`].
///
/// `ses` appears here, but requests that carry an `Action` parameter are
/// classified as Query before this table is ever consulted.
const REST_JSON_SERVICES: &[&str] = &[
    "lambda",
    "ses",
    "apigateway",
    "bedrock",
    "bedrock-agent",
    "bedrock-agent-runtime",
    "scheduler",
];

/// Result of request detection: which service handles the request, which
/// operation (if known) and how the payload is encoded.
#[derive(Clone, Debug)]
pub struct DetectedRequest {
    /// Service name such as `"s3"` or `"sqs"`.
    pub service: String,
    /// Operation name; left empty for REST requests, whose operation is not
    /// determined during detection.
    pub action: String,
    /// Protocol family used to decode the request.
    pub protocol: AwsProtocol,
}
43
44pub fn detect_service_headers_only(
51 headers: &HeaderMap,
52 query_params: &HashMap<String, String>,
53) -> Option<DetectedRequest> {
54 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
56 return parse_amz_target(target);
57 }
58 if let Some(action) = query_params.get("Action") {
59 let service = extract_service_from_auth(headers)
60 .or_else(|| infer_service_from_action(action))
61 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
62 if let Some(service) = service {
63 return Some(DetectedRequest {
64 service,
65 action: action.clone(),
66 protocol: AwsProtocol::Query,
67 });
68 }
69 }
70 if let Some(service) = extract_service_from_auth(headers) {
71 if let Some(protocol) = rest_protocol_for(&service) {
72 return Some(DetectedRequest {
73 service,
74 action: String::new(),
75 protocol,
76 });
77 }
78 }
79 if let Some(credential) = query_params.get("X-Amz-Credential") {
80 let parts: Vec<&str> = credential.split('/').collect();
81 if parts.len() >= 4 {
82 let service = normalize_service_name(parts[3]).to_string();
83 if let Some(protocol) = rest_protocol_for(&service) {
84 return Some(DetectedRequest {
85 service,
86 action: String::new(),
87 protocol,
88 });
89 }
90 }
91 }
92 if query_params.contains_key("AWSAccessKeyId")
93 && query_params.contains_key("Signature")
94 && query_params.contains_key("Expires")
95 {
96 return Some(DetectedRequest {
97 service: "s3".to_string(),
98 action: String::new(),
99 protocol: AwsProtocol::Rest,
100 });
101 }
102 if let Some(host_info) = parse_routing_host_from_headers(headers) {
103 if let Some(protocol) = rest_protocol_for(&host_info.service) {
104 return Some(DetectedRequest {
105 service: host_info.service,
106 action: String::new(),
107 protocol,
108 });
109 }
110 }
111 None
112}
113
114pub fn detect_service(
116 headers: &HeaderMap,
117 query_params: &HashMap<String, String>,
118 body: &Bytes,
119) -> Option<DetectedRequest> {
120 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
122 return parse_amz_target(target);
123 }
124
125 if let Some(action) = query_params.get("Action") {
127 let service = extract_service_from_auth(headers)
128 .or_else(|| infer_service_from_action(action))
129 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
130 if let Some(service) = service {
131 return Some(DetectedRequest {
132 service,
133 action: action.clone(),
134 protocol: AwsProtocol::Query,
135 });
136 }
137 }
138
139 {
141 let form_params = decode_form_urlencoded(body);
142
143 if let Some(action) = form_params.get("Action") {
144 let service = extract_service_from_auth(headers)
145 .or_else(|| infer_service_from_action(action))
146 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
147 if let Some(service) = service {
148 return Some(DetectedRequest {
149 service,
150 action: action.clone(),
151 protocol: AwsProtocol::Query,
152 });
153 }
154 }
155 }
156
157 if let Some(service) = extract_service_from_auth(headers) {
159 if let Some(protocol) = rest_protocol_for(&service) {
160 return Some(DetectedRequest {
161 service,
162 action: String::new(), protocol,
164 });
165 }
166 }
167
168 if let Some(credential) = query_params.get("X-Amz-Credential") {
170 let parts: Vec<&str> = credential.split('/').collect();
172 if parts.len() >= 4 {
173 let service = normalize_service_name(parts[3]).to_string();
174 if let Some(protocol) = rest_protocol_for(&service) {
175 return Some(DetectedRequest {
176 service,
177 action: String::new(),
178 protocol,
179 });
180 }
181 }
182 }
183
184 if query_params.contains_key("AWSAccessKeyId")
188 && query_params.contains_key("Signature")
189 && query_params.contains_key("Expires")
190 {
191 return Some(DetectedRequest {
192 service: "s3".to_string(),
193 action: String::new(),
194 protocol: AwsProtocol::Rest,
195 });
196 }
197
198 if let Some(host_info) = parse_routing_host_from_headers(headers) {
202 if let Some(protocol) = rest_protocol_for(&host_info.service) {
203 return Some(DetectedRequest {
204 service: host_info.service,
205 action: String::new(),
206 protocol,
207 });
208 }
209 }
210
211 None
212}
213
/// Routing information recovered from a `Host` header value.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RoutingHost {
    /// Service label taken from the host name (`"s3"` for every S3 endpoint shape).
    pub service: String,
    /// Region label taken from the host name (`"us-east-1"` for legacy global S3 hosts).
    pub region: String,
    /// Bucket (or access-point) name for virtual-hosted S3 hosts; `None` otherwise.
    pub bucket: Option<String>,
}
229
230const LOCALSTACK_SUFFIX: &str = ".localhost.localstack.cloud";
231const AWS_SUFFIX: &str = ".amazonaws.com";
232
233pub fn parse_routing_host(host: &str) -> Option<RoutingHost> {
237 let hostname = host.split(':').next()?;
238 if hostname.is_empty() {
239 return None;
240 }
241 let hostname = hostname.to_ascii_lowercase();
242 if let Some(prefix) = hostname.strip_suffix(LOCALSTACK_SUFFIX) {
243 return parse_localstack_prefix(prefix);
244 }
245 if hostname == "amazonaws.com" {
246 return None;
247 }
248 if let Some(prefix) = hostname.strip_suffix(AWS_SUFFIX) {
249 return parse_aws_prefix(prefix);
250 }
251 None
252}
253
254pub fn parse_routing_host_from_headers(headers: &HeaderMap) -> Option<RoutingHost> {
256 let host = headers.get("host")?.to_str().ok()?;
257 parse_routing_host(host)
258}
259
260fn parse_localstack_prefix(prefix: &str) -> Option<RoutingHost> {
261 if prefix.is_empty() {
262 return None;
263 }
264 let labels: Vec<&str> = prefix.split('.').collect();
265 if labels.iter().any(|l| l.is_empty()) {
266 return None;
267 }
268 match labels.len() {
269 2 => Some(RoutingHost {
270 service: labels[0].to_string(),
271 region: labels[1].to_string(),
272 bucket: None,
273 }),
274 n if n >= 3 && labels[n - 2] == "s3" => {
275 let bucket = labels[..n - 2].join(".");
276 Some(RoutingHost {
277 service: "s3".to_string(),
278 region: labels[n - 1].to_string(),
279 bucket: Some(bucket),
280 })
281 }
282 n if n >= 3 && labels[n - 2] == "s3-accesspoint" => {
283 let bucket = labels[..n - 2].join(".");
284 Some(RoutingHost {
285 service: "s3".to_string(),
286 region: labels[n - 1].to_string(),
287 bucket: Some(bucket),
288 })
289 }
290 n if n >= 3 && labels[n - 2] == "s3-control" => Some(RoutingHost {
291 service: "s3".to_string(),
292 region: labels[n - 1].to_string(),
293 bucket: None,
294 }),
295 _ => None,
296 }
297}
298
299fn parse_aws_prefix(prefix: &str) -> Option<RoutingHost> {
311 if prefix.is_empty() {
312 return None;
313 }
314 let labels: Vec<&str> = prefix.split('.').collect();
315 if labels.iter().any(|l| l.is_empty()) {
316 return None;
317 }
318 let last = *labels.last()?;
319
320 if let Some(region) = last.strip_prefix("s3-") {
323 if !region.is_empty() {
324 let bucket = if labels.len() >= 2 {
325 Some(labels[..labels.len() - 1].join("."))
326 } else {
327 None
328 };
329 return Some(RoutingHost {
330 service: "s3".to_string(),
331 region: region.to_string(),
332 bucket,
333 });
334 }
335 }
336
337 if last == "s3" {
341 if labels.len() == 1 {
342 return Some(RoutingHost {
343 service: "s3".to_string(),
344 region: "us-east-1".to_string(),
345 bucket: None,
346 });
347 }
348 return Some(RoutingHost {
349 service: "s3".to_string(),
350 region: "us-east-1".to_string(),
351 bucket: Some(labels[..labels.len() - 1].join(".")),
352 });
353 }
354
355 if last == "s3-accesspoint" {
358 if labels.len() == 2 {
359 return Some(RoutingHost {
360 service: "s3".to_string(),
361 region: labels[0].to_string(),
362 bucket: None,
363 });
364 }
365 let bucket = labels[..labels.len() - 2].join(".");
366 return Some(RoutingHost {
367 service: "s3".to_string(),
368 region: labels[labels.len() - 1].to_string(),
369 bucket: Some(bucket),
370 });
371 }
372
373 if labels.len() >= 2 && labels[labels.len() - 2] == "s3-control" {
376 return Some(RoutingHost {
377 service: "s3".to_string(),
378 region: last.to_string(),
379 bucket: None,
380 });
381 }
382
383 match labels.len() {
384 2 => Some(RoutingHost {
387 service: labels[0].to_string(),
388 region: labels[1].to_string(),
389 bucket: None,
390 }),
391 n if n >= 3 && labels[n - 2] == "s3" => {
393 let bucket = labels[..n - 2].join(".");
394 Some(RoutingHost {
395 service: "s3".to_string(),
396 region: labels[n - 1].to_string(),
397 bucket: Some(bucket),
398 })
399 }
400 _ => None,
401 }
402}
403
404fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
407 let (prefix, action) = target.rsplit_once('.')?;
408
409 let service = match prefix {
410 "AWSEvents" => "events",
411 "AmazonSSM" => "ssm",
412 "AmazonSQS" => "sqs",
413 "AmazonSNS" => "sns",
414 "DynamoDB_20120810" => "dynamodb",
415 "DynamoDBStreams_20120810" => "dynamodbstreams",
416 "Logs_20140328" => "logs",
417 s if s.starts_with("secretsmanager") => "secretsmanager",
418 s if s.starts_with("TrentService") => "kms",
419 s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
420 s if s.starts_with("AWSCognitoIdentityService") => "cognito-identity",
421 s if s.starts_with("Kinesis_20131202") => "kinesis",
422 s if s.starts_with("AmazonEC2ContainerRegistry_V") => "ecr",
423 s if s.starts_with("AmazonEC2ContainerServiceV") => "ecs",
424 s if s.starts_with("AWSStepFunctions") => "states",
425 s if s.starts_with("AWSOrganizationsV") => "organizations",
426 "CertificateManager" => "acm",
427 "AnyScaleFrontendService" => "application-autoscaling",
428 "AWSWAF_20190729" => "wafv2",
431 "AmazonAthena" => "athena",
432 s if s.starts_with("Firehose_") => "firehose",
433 "AWSGlue" => "glue",
434 _ => return None,
435 };
436
437 Some(DetectedRequest {
438 service: service.to_string(),
439 action: action.to_string(),
440 protocol: AwsProtocol::Json,
441 })
442}
443
444fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
446 if REST_XML_SERVICES.contains(&service) {
447 Some(AwsProtocol::Rest)
448 } else if REST_JSON_SERVICES.contains(&service) {
449 Some(AwsProtocol::RestJson)
450 } else {
451 None
452 }
453}
454
/// Guesses the service of a Query-protocol request from its `Action` name alone.
///
/// Only a fixed set of STS, IAM and SES action names is recognised. The match
/// is exact and case-sensitive, and any other name yields `None`. Callers use
/// this only when the service cannot be taken from the request signature.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    const SES_ACTIONS: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];

    // The three tables are disjoint, so the lookup order is irrelevant.
    [("sts", STS_ACTIONS), ("iam", IAM_ACTIONS), ("ses", SES_ACTIONS)]
        .iter()
        .find(|&&(_, actions)| actions.contains(&action))
        .map(|&(service, _)| service.to_string())
}
522
523fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
525 let auth = headers.get("authorization")?.to_str().ok()?;
526 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
527 Some(normalize_service_name(&info.service).to_string())
528}
529
/// Maps signing-name aliases onto the canonical service name used for routing.
///
/// Currently `bedrock-runtime` → `bedrock` and `apigatewayv2` → `apigateway`;
/// every other name, including the empty string, is returned unchanged.
fn normalize_service_name(service: &str) -> &str {
    const ALIASES: &[(&str, &str)] = &[
        ("bedrock-runtime", "bedrock"),
        ("apigatewayv2", "apigateway"),
    ];
    ALIASES
        .iter()
        .find(|&&(alias, _)| alias == service)
        .map_or(service, |&(_, canonical)| canonical)
}
555
556pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
558 decode_form_urlencoded(body)
559}
560
/// Parses an `application/x-www-form-urlencoded` byte string into a map.
///
/// - Pairs are separated by `&`; empty segments are skipped.
/// - The first `=` splits key from value; a pair without `=` maps to `""`.
/// - Keys and values are decoded with [`url_decode`].
/// - When a key repeats, the last occurrence wins.
/// - Input that is not valid UTF-8 yields an empty map.
fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
    let text = match std::str::from_utf8(input) {
        Ok(text) => text,
        Err(_) => return HashMap::new(),
    };
    text.split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (url_decode(key), url_decode(value))
        })
        .collect()
}

/// Percent-decodes a form component.
///
/// - `+` becomes a space and `%XX` becomes the byte `0xXX`.
/// - The resulting bytes are interpreted as UTF-8, so multi-byte sequences
///   such as `%C3%A9` decode to `é`. Invalid sequences become U+FFFD.
/// - Previously every byte was cast to `char` individually, which turned
///   non-ASCII text into mojibake.
///
/// A malformed escape (`%` not followed by two hex digits) is dropped together
/// with the (up to) two bytes after it.
fn url_decode(input: &str) -> String {
    let mut decoded: Vec<u8> = Vec::with_capacity(input.len());
    let mut bytes = input.bytes();
    while let Some(byte) = bytes.next() {
        match byte {
            b'+' => decoded.push(b' '),
            b'%' => {
                // Both bytes are consumed even when the first is not hex.
                let high = bytes.next().and_then(from_hex);
                let low = bytes.next().and_then(from_hex);
                if let (Some(h), Some(l)) = (high, low) {
                    decoded.push((h << 4) | l);
                }
            }
            other => decoded.push(other),
        }
    }
    String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}

/// Value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
604
// Unit tests for request detection: X-Amz-Target parsing, form/query decoding,
// action-based service inference, host-based routing and the end-to-end
// `detect_service` precedence rules.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- X-Amz-Target parsing ----

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_kinesis() {
        let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
        assert_eq!(result.service, "kinesis");
        assert_eq!(result.action, "ListStreams");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    // ---- Form body decoding ----

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }

    #[test]
    fn parse_query_body_empty_returns_empty_map() {
        let body = Bytes::from("");
        let params = parse_query_body(&body);
        assert!(params.is_empty());
    }

    #[test]
    fn parse_query_body_duplicate_keys_last_wins() {
        let body = Bytes::from("key=a&key=b");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "b");
    }

    #[test]
    fn parse_query_body_single_key() {
        let body = Bytes::from("key=value");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "value");
    }

    // ---- More X-Amz-Target coverage (versioned prefixes, failures) ----

    #[test]
    fn parse_amz_target_ecs() {
        let result = parse_amz_target("AmazonEC2ContainerServiceV20141113.ListClusters").unwrap();
        assert_eq!(result.service, "ecs");
        assert_eq!(result.action, "ListClusters");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_invalid_returns_none() {
        // No '.' separator means there is no action to extract.
        assert!(parse_amz_target("NoDotHere").is_none());
        assert!(parse_amz_target("").is_none());
    }

    #[test]
    fn parse_amz_target_various_prefixes() {
        assert_eq!(
            parse_amz_target("AmazonSQS.SendMessage").unwrap().service,
            "sqs"
        );
        assert_eq!(
            parse_amz_target("AmazonSNS.Publish").unwrap().service,
            "sns"
        );
        assert_eq!(
            parse_amz_target("DynamoDB_20120810.GetItem")
                .unwrap()
                .service,
            "dynamodb"
        );
        assert_eq!(
            parse_amz_target("Logs_20140328.PutLogEvents")
                .unwrap()
                .service,
            "logs"
        );
        assert_eq!(
            parse_amz_target("secretsmanager.GetSecretValue")
                .unwrap()
                .service,
            "secretsmanager"
        );
        assert_eq!(
            parse_amz_target("TrentService.Encrypt").unwrap().service,
            "kms"
        );
        assert_eq!(
            parse_amz_target("AWSCognitoIdentityProviderService.InitiateAuth")
                .unwrap()
                .service,
            "cognito-idp"
        );
        assert_eq!(
            parse_amz_target("AWSStepFunctions.StartExecution")
                .unwrap()
                .service,
            "states"
        );
        assert_eq!(
            parse_amz_target("AWSOrganizationsV20161128.CreateOrganization")
                .unwrap()
                .service,
            "organizations"
        );
        assert!(parse_amz_target("UnknownServicePrefix.Action").is_none());
    }

    // ---- Service inference from a bare Action name ----

    #[test]
    fn infer_service_from_action_maps_sts() {
        assert_eq!(
            infer_service_from_action("AssumeRole").as_deref(),
            Some("sts")
        );
        assert_eq!(
            infer_service_from_action("GetCallerIdentity").as_deref(),
            Some("sts")
        );
    }

    #[test]
    fn infer_service_from_action_maps_iam() {
        assert_eq!(
            infer_service_from_action("CreateUser").as_deref(),
            Some("iam")
        );
        assert_eq!(
            infer_service_from_action("ListRoles").as_deref(),
            Some("iam")
        );
    }

    #[test]
    fn infer_service_from_action_maps_ses() {
        assert_eq!(
            infer_service_from_action("SendEmail").as_deref(),
            Some("ses")
        );
        assert_eq!(
            infer_service_from_action("ListIdentities").as_deref(),
            Some("ses")
        );
    }

    #[test]
    fn infer_service_from_action_unknown_returns_none() {
        assert!(infer_service_from_action("NotARealAction").is_none());
    }

    #[test]
    fn rest_protocol_for_returns_none_for_non_rest_service() {
        assert!(rest_protocol_for("sqs").is_none());
    }

    // ---- Percent-decoding helpers ----

    #[test]
    fn url_decode_handles_percent_and_plus() {
        assert_eq!(url_decode("hello+world"), "hello world");
        assert_eq!(url_decode("hello%20world"), "hello world");
        assert_eq!(url_decode("100%25"), "100%");
    }

    #[test]
    fn url_decode_ignores_malformed_percent() {
        // A bad escape is dropped along with the two bytes that follow '%'.
        assert_eq!(url_decode("%ZZ"), "");
    }

    #[test]
    fn from_hex_valid_digits() {
        assert_eq!(from_hex(b'0'), Some(0));
        assert_eq!(from_hex(b'9'), Some(9));
        assert_eq!(from_hex(b'a'), Some(10));
        assert_eq!(from_hex(b'F'), Some(15));
    }

    #[test]
    fn from_hex_invalid_returns_none() {
        assert!(from_hex(b'g').is_none());
        assert!(from_hex(b' ').is_none());
    }

    // ---- End-to-end detect_service strategies ----

    #[test]
    fn detect_service_via_amz_target() {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", "AmazonSSM.GetParameter".parse().unwrap());
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ssm");
        assert_eq!(detected.action, "GetParameter");
    }

    #[test]
    fn detect_service_via_query_action_with_inferred_service() {
        // No auth or host header: the service must come from the action table.
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "AssumeRole".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "AssumeRole");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_via_form_body() {
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::from("Action=SendEmail&Source=x%40y.com");
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ses");
        assert_eq!(detected.action, "SendEmail");
    }

    #[test]
    fn detect_service_via_sigv2_presigned() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("AWSAccessKeyId".to_string(), "AKID".to_string());
        query.insert("Signature".to_string(), "sig".to_string());
        query.insert("Expires".to_string(), "1234567890".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_via_sigv4_presigned_credential() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/s3/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_unknown_returns_none() {
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::new();
        assert!(detect_service(&headers, &query, &body).is_none());
    }

    // ---- Service-name aliasing ----

    #[test]
    fn normalize_service_name_aliases_apigatewayv2_to_apigateway() {
        assert_eq!(normalize_service_name("apigatewayv2"), "apigateway");
    }

    #[test]
    fn normalize_service_name_aliases_bedrock_runtime_to_bedrock() {
        assert_eq!(normalize_service_name("bedrock-runtime"), "bedrock");
    }

    #[test]
    fn normalize_service_name_passes_through_unaliased_services() {
        assert_eq!(normalize_service_name("bedrock"), "bedrock");
        assert_eq!(normalize_service_name("s3"), "s3");
        assert_eq!(normalize_service_name("lambda"), "lambda");
        assert_eq!(normalize_service_name(""), "");
        assert_eq!(
            normalize_service_name("unknown-future-service"),
            "unknown-future-service"
        );
    }

    #[test]
    fn detect_service_via_authorization_header_normalizes_bedrock_runtime() {
        let mut headers = HeaderMap::new();
        // The signing name is `bedrock-runtime`; detection must report the
        // normalised `bedrock` so it lands on the REST-JSON table entry.
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 \
             Credential=AKID/20240101/us-east-1/bedrock-runtime/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn detect_service_via_sigv4_presigned_credential_normalizes_bedrock_runtime() {
        // Same normalisation, but via the presigned X-Amz-Credential parameter.
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/bedrock-runtime/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    // ---- Host-based routing: LocalStack-style hosts ----

    #[test]
    fn parse_routing_host_localstack_basic() {
        let h = parse_routing_host("sqs.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_localstack_with_port() {
        let h = parse_routing_host("lambda.eu-west-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "eu-west-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_case_insensitive() {
        let h = parse_routing_host("SQS.US-EAST-1.LOCALHOST.LOCALSTACK.CLOUD:4566").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");

        let h = parse_routing_host("LAMBDA.US-EAST-1.AMAZONAWS.COM").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "us-east-1");
    }

    #[test]
    fn parse_routing_host_localstack_s3_virtual_hosted() {
        let h =
            parse_routing_host("my-bucket.s3.us-east-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_localstack_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    // ---- Host-based routing: AWS hosts ----

    #[test]
    fn parse_routing_host_aws_service_region() {
        let h = parse_routing_host("sqs.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("dynamodb.eu-west-2.amazonaws.com:443").unwrap();
        assert_eq!(h.service, "dynamodb");
        assert_eq!(h.region, "eu-west-2");
    }

    #[test]
    fn parse_routing_host_aws_s3_path_style_modern() {
        let h = parse_routing_host("s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_aws_s3_virtual_hosted_modern() {
        let h = parse_routing_host("my-bucket.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_s3_legacy_global() {
        // The region-less legacy endpoint is reported as us-east-1.
        let h = parse_routing_host("s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_legacy_global_dotted_bucket() {
        // Every label before the trailing `s3` belongs to the bucket name.
        let h = parse_routing_host("a.b.c.s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_s3_dash_separated() {
        // Legacy `s3-<region>` form: the region is embedded in the s3 label.
        let h = parse_routing_host("s3-us-west-2.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3-us-west-2.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    // ---- Host-based routing: rejection cases ----

    #[test]
    fn parse_routing_host_rejects_plain_localhost() {
        assert!(parse_routing_host("localhost:4566").is_none());
        assert!(parse_routing_host("127.0.0.1:4566").is_none());
    }

    #[test]
    fn parse_routing_host_rejects_unknown_suffix() {
        assert!(parse_routing_host("sqs.us-east-1.example.com").is_none());
        assert!(parse_routing_host("s3.us-east-1.aws").is_none());
    }

    #[test]
    fn parse_routing_host_empty_and_malformed_rejected() {
        assert!(parse_routing_host("").is_none());
        assert!(parse_routing_host(".localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("..localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("sqs.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("foo.bar.baz.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host(".amazonaws.com").is_none());
        assert!(parse_routing_host("amazonaws.com").is_none());
    }

    // ---- detect_service with a Host header ----

    #[test]
    fn detect_service_via_host_for_rest_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_via_host_for_rest_json_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "lambda");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn detect_service_via_host_plus_query_action() {
        // `ListQueues` is not in the action table, so the service comes from the host.
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "sqs.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "ListQueues".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sqs");
        assert_eq!(detected.action, "ListQueues");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_sigv4_wins_over_host() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 Credential=AKID/20240101/us-east-1/s3/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        // The signed service (s3) takes precedence over the host's service (lambda).
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_host_for_virtual_hosted_s3() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "my-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }
}