1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Wire-protocol family assigned to a detected AWS request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AwsProtocol {
    /// Query protocol: the operation is given by an `Action` parameter, taken
    /// from the URL query string or a form-encoded body.
    Query,
    /// JSON protocol: the operation is named by the `X-Amz-Target` header
    /// (`<Prefix>.<Operation>`).
    Json,
    /// REST protocol for the services listed in `REST_XML_SERVICES`.
    Rest,
    /// REST protocol for the services listed in `REST_JSON_SERVICES`.
    RestJson,
}
21
/// Service names routed as [`AwsProtocol::Rest`] when no Query/JSON signal is
/// present.
const REST_XML_SERVICES: &[&str] = &[
    "s3",
    "cloudfront",
    "route53",
];

/// Service names routed as [`AwsProtocol::RestJson`] when no Query/JSON signal
/// is present. Aliases such as `bedrock-runtime` and `apigatewayv2` are folded
/// into `bedrock` / `apigateway` by `normalize_service_name` before lookup.
const REST_JSON_SERVICES: &[&str] = &[
    "lambda",
    "ses",
    "apigateway",
    "bedrock",
    "bedrock-agent",
    "bedrock-agent-runtime",
    "scheduler",
    "batch",
    "pipes",
    "rds-data",
    "dsql",
    "resource-groups",
    "eks",
];
41
42#[derive(Debug, Clone)]
44pub struct DetectedRequest {
45 pub service: String,
46 pub action: String,
47 pub protocol: AwsProtocol,
48}
49
50pub fn detect_service_headers_only(
57 headers: &HeaderMap,
58 query_params: &HashMap<String, String>,
59) -> Option<DetectedRequest> {
60 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
62 return parse_amz_target(target);
63 }
64 if let Some(action) = query_params.get("Action") {
65 let service = extract_service_from_auth(headers)
66 .or_else(|| infer_service_from_action(action))
67 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
68 if let Some(service) = service {
69 return Some(DetectedRequest {
70 service,
71 action: action.clone(),
72 protocol: AwsProtocol::Query,
73 });
74 }
75 }
76 if let Some(service) = extract_service_from_auth(headers) {
77 if let Some(protocol) = rest_protocol_for(&service) {
78 return Some(DetectedRequest {
79 service,
80 action: String::new(),
81 protocol,
82 });
83 }
84 }
85 if let Some(credential) = query_params.get("X-Amz-Credential") {
86 let parts: Vec<&str> = credential.split('/').collect();
87 if parts.len() >= 4 {
88 let service = normalize_service_name(parts[3]).to_string();
89 if let Some(protocol) = rest_protocol_for(&service) {
90 return Some(DetectedRequest {
91 service,
92 action: String::new(),
93 protocol,
94 });
95 }
96 }
97 }
98 if query_params.contains_key("AWSAccessKeyId")
99 && query_params.contains_key("Signature")
100 && query_params.contains_key("Expires")
101 {
102 return Some(DetectedRequest {
103 service: "s3".to_string(),
104 action: String::new(),
105 protocol: AwsProtocol::Rest,
106 });
107 }
108 if let Some(host_info) = parse_routing_host_from_headers(headers) {
109 if let Some(protocol) = rest_protocol_for(&host_info.service) {
110 return Some(DetectedRequest {
111 service: host_info.service,
112 action: String::new(),
113 protocol,
114 });
115 }
116 }
117 None
118}
119
120pub fn detect_service(
122 headers: &HeaderMap,
123 query_params: &HashMap<String, String>,
124 body: &Bytes,
125) -> Option<DetectedRequest> {
126 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
128 return parse_amz_target(target);
129 }
130
131 if let Some(action) = query_params.get("Action") {
133 let service = extract_service_from_auth(headers)
134 .or_else(|| infer_service_from_action(action))
135 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
136 if let Some(service) = service {
137 return Some(DetectedRequest {
138 service,
139 action: action.clone(),
140 protocol: AwsProtocol::Query,
141 });
142 }
143 }
144
145 {
147 let form_params = decode_form_urlencoded(body);
148
149 if let Some(action) = form_params.get("Action") {
150 let service = extract_service_from_auth(headers)
151 .or_else(|| infer_service_from_action(action))
152 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
153 if let Some(service) = service {
154 return Some(DetectedRequest {
155 service,
156 action: action.clone(),
157 protocol: AwsProtocol::Query,
158 });
159 }
160 }
161 }
162
163 if let Some(service) = extract_service_from_auth(headers) {
165 if let Some(protocol) = rest_protocol_for(&service) {
166 return Some(DetectedRequest {
167 service,
168 action: String::new(), protocol,
170 });
171 }
172 }
173
174 if let Some(credential) = query_params.get("X-Amz-Credential") {
176 let parts: Vec<&str> = credential.split('/').collect();
178 if parts.len() >= 4 {
179 let service = normalize_service_name(parts[3]).to_string();
180 if let Some(protocol) = rest_protocol_for(&service) {
181 return Some(DetectedRequest {
182 service,
183 action: String::new(),
184 protocol,
185 });
186 }
187 }
188 }
189
190 if query_params.contains_key("AWSAccessKeyId")
194 && query_params.contains_key("Signature")
195 && query_params.contains_key("Expires")
196 {
197 return Some(DetectedRequest {
198 service: "s3".to_string(),
199 action: String::new(),
200 protocol: AwsProtocol::Rest,
201 });
202 }
203
204 if let Some(host_info) = parse_routing_host_from_headers(headers) {
208 if let Some(protocol) = rest_protocol_for(&host_info.service) {
209 return Some(DetectedRequest {
210 service: host_info.service,
211 action: String::new(),
212 protocol,
213 });
214 }
215 }
216
217 None
218}
219
/// Routing information extracted from a request's `Host` header.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RoutingHost {
    /// Service name derived from the hostname, e.g. `"sqs"` or `"s3"`.
    pub service: String,
    /// Region derived from the hostname (`"us-east-1"` for the legacy global
    /// S3 endpoint).
    pub region: String,
    /// Bucket or access-point name for virtual-hosted S3 hosts; may contain
    /// dots. `None` for every other host form.
    pub bucket: Option<String>,
}
235
236const LOCALSTACK_SUFFIX: &str = ".localhost.localstack.cloud";
237const AWS_SUFFIX: &str = ".amazonaws.com";
238
239pub fn parse_routing_host(host: &str) -> Option<RoutingHost> {
243 let hostname = host.split(':').next()?;
244 if hostname.is_empty() {
245 return None;
246 }
247 let hostname = hostname.to_ascii_lowercase();
248 if let Some(prefix) = hostname.strip_suffix(LOCALSTACK_SUFFIX) {
249 return parse_localstack_prefix(prefix);
250 }
251 if hostname == "amazonaws.com" {
252 return None;
253 }
254 if let Some(prefix) = hostname.strip_suffix(AWS_SUFFIX) {
255 return parse_aws_prefix(prefix);
256 }
257 None
258}
259
260pub fn parse_routing_host_from_headers(headers: &HeaderMap) -> Option<RoutingHost> {
262 let host = headers.get("host")?.to_str().ok()?;
263 parse_routing_host(host)
264}
265
266fn parse_localstack_prefix(prefix: &str) -> Option<RoutingHost> {
267 if prefix.is_empty() {
268 return None;
269 }
270 let labels: Vec<&str> = prefix.split('.').collect();
271 if labels.iter().any(|l| l.is_empty()) {
272 return None;
273 }
274 match labels.len() {
275 2 => Some(RoutingHost {
276 service: labels[0].to_string(),
277 region: labels[1].to_string(),
278 bucket: None,
279 }),
280 n if n >= 3 && labels[n - 2] == "s3" => {
281 let bucket = labels[..n - 2].join(".");
282 Some(RoutingHost {
283 service: "s3".to_string(),
284 region: labels[n - 1].to_string(),
285 bucket: Some(bucket),
286 })
287 }
288 n if n >= 3 && labels[n - 2] == "s3-accesspoint" => {
289 let bucket = labels[..n - 2].join(".");
290 Some(RoutingHost {
291 service: "s3".to_string(),
292 region: labels[n - 1].to_string(),
293 bucket: Some(bucket),
294 })
295 }
296 n if n >= 3 && labels[n - 2] == "s3-control" => Some(RoutingHost {
297 service: "s3".to_string(),
298 region: labels[n - 1].to_string(),
299 bucket: None,
300 }),
301 _ => None,
302 }
303}
304
305fn parse_aws_prefix(prefix: &str) -> Option<RoutingHost> {
317 if prefix.is_empty() {
318 return None;
319 }
320 let labels: Vec<&str> = prefix.split('.').collect();
321 if labels.iter().any(|l| l.is_empty()) {
322 return None;
323 }
324 let last = *labels.last()?;
325
326 if let Some(region) = last.strip_prefix("s3-") {
329 if !region.is_empty() {
330 let bucket = if labels.len() >= 2 {
331 Some(labels[..labels.len() - 1].join("."))
332 } else {
333 None
334 };
335 return Some(RoutingHost {
336 service: "s3".to_string(),
337 region: region.to_string(),
338 bucket,
339 });
340 }
341 }
342
343 if last == "s3" {
347 if labels.len() == 1 {
348 return Some(RoutingHost {
349 service: "s3".to_string(),
350 region: "us-east-1".to_string(),
351 bucket: None,
352 });
353 }
354 return Some(RoutingHost {
355 service: "s3".to_string(),
356 region: "us-east-1".to_string(),
357 bucket: Some(labels[..labels.len() - 1].join(".")),
358 });
359 }
360
361 if last == "s3-accesspoint" {
364 if labels.len() == 2 {
365 return Some(RoutingHost {
366 service: "s3".to_string(),
367 region: labels[0].to_string(),
368 bucket: None,
369 });
370 }
371 if labels.len() >= 3 {
375 let bucket = labels[..labels.len() - 2].join(".");
376 return Some(RoutingHost {
377 service: "s3".to_string(),
378 region: labels[labels.len() - 1].to_string(),
379 bucket: Some(bucket),
380 });
381 }
382 }
383
384 if labels.len() >= 2 && labels[labels.len() - 2] == "s3-control" {
387 return Some(RoutingHost {
388 service: "s3".to_string(),
389 region: last.to_string(),
390 bucket: None,
391 });
392 }
393
394 match labels.len() {
395 2 => Some(RoutingHost {
398 service: labels[0].to_string(),
399 region: labels[1].to_string(),
400 bucket: None,
401 }),
402 n if n >= 3 && labels[n - 2] == "s3" => {
404 let bucket = labels[..n - 2].join(".");
405 Some(RoutingHost {
406 service: "s3".to_string(),
407 region: labels[n - 1].to_string(),
408 bucket: Some(bucket),
409 })
410 }
411 _ => None,
412 }
413}
414
415fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
418 let (prefix, action) = target.rsplit_once('.')?;
419
420 let service = match prefix {
421 "AWSEvents" => "events",
422 "AmazonSSM" => "ssm",
423 "AmazonSQS" => "sqs",
424 "AmazonSNS" => "sns",
425 "DynamoDB_20120810" => "dynamodb",
426 "DynamoDBStreams_20120810" => "dynamodbstreams",
427 "Logs_20140328" => "logs",
428 s if s.starts_with("secretsmanager") => "secretsmanager",
429 s if s.starts_with("TrentService") => "kms",
430 s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
431 s if s.starts_with("AWSCognitoIdentityService") => "cognito-identity",
432 s if s.starts_with("Kinesis_20131202") => "kinesis",
433 s if s.starts_with("AmazonEC2ContainerRegistry_V") => "ecr",
434 s if s.starts_with("AmazonEC2ContainerServiceV") => "ecs",
435 s if s.starts_with("AWSStepFunctions") => "states",
436 s if s.starts_with("AWSOrganizationsV") => "organizations",
437 "CertificateManager" => "acm",
438 "AnyScaleFrontendService" => "application-autoscaling",
439 "AWSWAF_20190729" => "wafv2",
442 "AmazonAthena" => "athena",
443 s if s.starts_with("Firehose_") => "firehose",
444 "AWSGlue" => "glue",
445 "CloudApiService" => "cloudcontrolapi",
446 "ResourceGroupsTaggingAPI_20170126" => "tagging",
447 "AmazonMemoryDB" => "memorydb",
448 "Route53AutoNaming_v20170314" => "servicediscovery",
451 s if s.starts_with("GraniteServiceVersion") => "monitoring",
457 _ => return None,
458 };
459
460 Some(DetectedRequest {
461 service: service.to_string(),
462 action: action.to_string(),
463 protocol: AwsProtocol::Json,
464 })
465}
466
467fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
469 if REST_XML_SERVICES.contains(&service) {
470 Some(AwsProtocol::Rest)
471 } else if REST_JSON_SERVICES.contains(&service) {
472 Some(AwsProtocol::RestJson)
473 } else {
474 None
475 }
476}
477
/// Guesses the service of a Query-protocol request from its `Action` name
/// alone, for requests that carry no signing information.
///
/// Only a fixed set of STS, IAM, SES and SNS actions is recognised; anything
/// else yields `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    const SES: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];
    // Subscription confirmation / unsubscribe links may be followed without
    // AWS credentials, so they must be routable by action name.
    const SNS: &[&str] = &["ConfirmSubscription", "Unsubscribe"];

    const BY_SERVICE: &[(&str, &[&str])] = &[("sts", STS), ("iam", IAM), ("ses", SES), ("sns", SNS)];

    BY_SERVICE
        .iter()
        .find(|(_, actions)| actions.contains(&action))
        .map(|(service, _)| service.to_string())
}
549
550fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
552 let auth = headers.get("authorization")?.to_str().ok()?;
553 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
554 Some(normalize_service_name(&info.service).to_string())
555}
556
/// Folds signing-name aliases onto the service name used for routing:
/// `bedrock-runtime` → `bedrock` and `apigatewayv2` → `apigateway`.
/// Every other name is returned unchanged.
fn normalize_service_name(service: &str) -> &str {
    const ALIASES: [(&str, &str); 2] = [
        ("bedrock-runtime", "bedrock"),
        ("apigatewayv2", "apigateway"),
    ];
    ALIASES
        .iter()
        .find(|(alias, _)| *alias == service)
        .map_or(service, |&(_, canonical)| canonical)
}
582
583pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
585 decode_form_urlencoded(body)
586}
587
588pub fn flatten_json_to_query(body: &Bytes) -> HashMap<String, String> {
605 let mut out = HashMap::new();
606 let Ok(value) = serde_json::from_slice::<serde_json::Value>(body) else {
607 return out;
608 };
609 if value.is_object() {
610 flatten_json_value("", &value, &mut out);
611 }
612 out
613}
614
615fn flatten_json_value(prefix: &str, value: &serde_json::Value, out: &mut HashMap<String, String>) {
616 match value {
617 serde_json::Value::Object(map) => {
618 for (k, v) in map {
619 let child = if prefix.is_empty() {
620 k.clone()
621 } else {
622 format!("{prefix}.{k}")
623 };
624 flatten_json_value(&child, v, out);
625 }
626 }
627 serde_json::Value::Array(items) => {
628 for (i, v) in items.iter().enumerate() {
629 let child = format!("{prefix}.member.{}", i + 1);
630 flatten_json_value(&child, v, out);
631 }
632 }
633 serde_json::Value::Null => {}
634 serde_json::Value::String(s) => {
635 out.insert(prefix.to_string(), s.clone());
636 }
637 serde_json::Value::Bool(b) => {
638 out.insert(prefix.to_string(), b.to_string());
639 }
640 serde_json::Value::Number(n) => {
641 out.insert(prefix.to_string(), n.to_string());
642 }
643 }
644}
645
/// Decodes `application/x-www-form-urlencoded` bytes into a key/value map.
///
/// Pairs are separated by `&` and split at the first `=`; a pair without `=`
/// maps to an empty value and empty pairs are skipped. Keys and values are
/// percent-decoded with `url_decode`. Duplicate keys keep the last value.
/// Input that is not valid UTF-8 yields an empty map.
fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
    let Ok(text) = std::str::from_utf8(input) else {
        return HashMap::new();
    };
    text.split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (url_decode(key), url_decode(value))
        })
        .collect()
}

/// Percent-decodes a form component: `+` becomes a space and `%XX` becomes
/// the byte `0xXX`.
///
/// The decoded bytes are interpreted as UTF-8, so multi-byte sequences such
/// as `%C3%A9` produce `é`. Invalid UTF-8 is replaced with U+FFFD. A `%` not
/// followed by two hex digits is dropped together with the (up to) two bytes
/// after it.
fn url_decode(input: &str) -> String {
    // Collect raw bytes first: pushing each byte as a `char` would treat it as
    // Latin-1 and corrupt multi-byte UTF-8 characters.
    let mut decoded: Vec<u8> = Vec::with_capacity(input.len());
    let mut bytes = input.bytes();
    while let Some(byte) = bytes.next() {
        match byte {
            b'+' => decoded.push(b' '),
            b'%' => {
                let high = bytes.next().and_then(from_hex);
                let low = bytes.next().and_then(from_hex);
                if let (Some(h), Some(l)) = (high, low) {
                    decoded.push((h << 4) | l);
                }
            }
            other => decoded.push(other),
        }
    }
    String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}

/// Returns the value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None`.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
689
#[cfg(test)]
mod tests {
    // Unit tests for request classification: X-Amz-Target parsing, Query
    // action inference, form/JSON body helpers, Host-based routing, and the
    // end-to-end detection order of `detect_service`.
    use super::*;

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_kinesis() {
        let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
        assert_eq!(result.service, "kinesis");
        assert_eq!(result.action, "ListStreams");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }

    #[test]
    fn parse_query_body_empty_returns_empty_map() {
        let body = Bytes::from("");
        let params = parse_query_body(&body);
        assert!(params.is_empty());
    }

    #[test]
    fn parse_query_body_duplicate_keys_last_wins() {
        let body = Bytes::from("key=a&key=b");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "b");
    }

    #[test]
    fn parse_query_body_single_key() {
        let body = Bytes::from("key=value");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "value");
    }

    #[test]
    fn parse_amz_target_ecs() {
        let result = parse_amz_target("AmazonEC2ContainerServiceV20141113.ListClusters").unwrap();
        assert_eq!(result.service, "ecs");
        assert_eq!(result.action, "ListClusters");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_invalid_returns_none() {
        assert!(parse_amz_target("NoDotHere").is_none());
        assert!(parse_amz_target("").is_none());
    }

    #[test]
    fn parse_amz_target_cloudwatch_json() {
        let result = parse_amz_target("GraniteServiceVersion20100801.PutMetricData").unwrap();
        assert_eq!(result.service, "monitoring");
        assert_eq!(result.action, "PutMetricData");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    // Nested objects become dotted keys; arrays become `.member.N` (1-based).
    #[test]
    fn flatten_json_to_query_nested() {
        let body = Bytes::from(
            serde_json::json!({
                "Namespace": "MyApp",
                "MetricData": [{
                    "MetricName": "Latency",
                    "Value": 12.5,
                    "StatisticValues": {"SampleCount": 3, "Sum": 10},
                    "Dimensions": [{"Name": "Endpoint", "Value": "/api"}]
                }]
            })
            .to_string(),
        );
        let flat = flatten_json_to_query(&body);
        assert_eq!(flat.get("Namespace").unwrap(), "MyApp");
        assert_eq!(
            flat.get("MetricData.member.1.MetricName").unwrap(),
            "Latency"
        );
        assert_eq!(flat.get("MetricData.member.1.Value").unwrap(), "12.5");
        assert_eq!(
            flat.get("MetricData.member.1.StatisticValues.SampleCount")
                .unwrap(),
            "3"
        );
        assert_eq!(
            flat.get("MetricData.member.1.Dimensions.member.1.Name")
                .unwrap(),
            "Endpoint"
        );
        assert_eq!(
            flat.get("MetricData.member.1.Dimensions.member.1.Value")
                .unwrap(),
            "/api"
        );
    }

    #[test]
    fn flatten_json_to_query_non_object_is_empty() {
        assert!(flatten_json_to_query(&Bytes::from_static(b"[]")).is_empty());
        assert!(flatten_json_to_query(&Bytes::from_static(b"not json")).is_empty());
    }

    #[test]
    fn parse_amz_target_various_prefixes() {
        assert_eq!(
            parse_amz_target("AmazonSQS.SendMessage").unwrap().service,
            "sqs"
        );
        assert_eq!(
            parse_amz_target("AmazonSNS.Publish").unwrap().service,
            "sns"
        );
        assert_eq!(
            parse_amz_target("DynamoDB_20120810.GetItem")
                .unwrap()
                .service,
            "dynamodb"
        );
        assert_eq!(
            parse_amz_target("Logs_20140328.PutLogEvents")
                .unwrap()
                .service,
            "logs"
        );
        assert_eq!(
            parse_amz_target("secretsmanager.GetSecretValue")
                .unwrap()
                .service,
            "secretsmanager"
        );
        assert_eq!(
            parse_amz_target("TrentService.Encrypt").unwrap().service,
            "kms"
        );
        assert_eq!(
            parse_amz_target("AWSCognitoIdentityProviderService.InitiateAuth")
                .unwrap()
                .service,
            "cognito-idp"
        );
        assert_eq!(
            parse_amz_target("AWSStepFunctions.StartExecution")
                .unwrap()
                .service,
            "states"
        );
        assert_eq!(
            parse_amz_target("AWSOrganizationsV20161128.CreateOrganization")
                .unwrap()
                .service,
            "organizations"
        );
        assert!(parse_amz_target("UnknownServicePrefix.Action").is_none());
    }

    #[test]
    fn infer_service_from_action_maps_sts() {
        assert_eq!(
            infer_service_from_action("AssumeRole").as_deref(),
            Some("sts")
        );
        assert_eq!(
            infer_service_from_action("GetCallerIdentity").as_deref(),
            Some("sts")
        );
    }

    #[test]
    fn infer_service_from_action_maps_iam() {
        assert_eq!(
            infer_service_from_action("CreateUser").as_deref(),
            Some("iam")
        );
        assert_eq!(
            infer_service_from_action("ListRoles").as_deref(),
            Some("iam")
        );
    }

    #[test]
    fn infer_service_from_action_maps_ses() {
        assert_eq!(
            infer_service_from_action("SendEmail").as_deref(),
            Some("ses")
        );
        assert_eq!(
            infer_service_from_action("ListIdentities").as_deref(),
            Some("ses")
        );
    }

    #[test]
    fn infer_service_from_action_maps_sns_confirmation_flow() {
        assert_eq!(
            infer_service_from_action("ConfirmSubscription").as_deref(),
            Some("sns")
        );
        assert_eq!(
            infer_service_from_action("Unsubscribe").as_deref(),
            Some("sns")
        );
    }

    #[test]
    fn detect_service_routes_unsigned_confirm_subscription_to_sns() {
        // No Authorization header and a plain `localhost` Host: the service
        // has to be inferred from the action name alone.
        let mut headers = HeaderMap::new();
        headers.insert("host", "localhost:4566".parse().unwrap());
        let mut query_params = HashMap::new();
        query_params.insert("Action".to_string(), "ConfirmSubscription".to_string());
        query_params.insert(
            "TopicArn".to_string(),
            "arn:aws:sns:us-east-1:000000000000:t".to_string(),
        );
        query_params.insert("Token".to_string(), "abc123".to_string());

        let detected = detect_service(&headers, &query_params, &Bytes::new())
            .expect("ConfirmSubscription must route to a service");
        assert_eq!(detected.service, "sns");
        assert_eq!(detected.action, "ConfirmSubscription");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn infer_service_from_action_unknown_returns_none() {
        assert!(infer_service_from_action("NotARealAction").is_none());
    }

    #[test]
    fn rest_protocol_for_returns_none_for_non_rest_service() {
        assert!(rest_protocol_for("sqs").is_none());
    }

    #[test]
    fn url_decode_handles_percent_and_plus() {
        assert_eq!(url_decode("hello+world"), "hello world");
        assert_eq!(url_decode("hello%20world"), "hello world");
        assert_eq!(url_decode("100%25"), "100%");
    }

    // A `%` not followed by two hex digits is dropped together with the two
    // bytes after it.
    #[test]
    fn url_decode_ignores_malformed_percent() {
        assert_eq!(url_decode("%ZZ"), "");
    }

    #[test]
    fn from_hex_valid_digits() {
        assert_eq!(from_hex(b'0'), Some(0));
        assert_eq!(from_hex(b'9'), Some(9));
        assert_eq!(from_hex(b'a'), Some(10));
        assert_eq!(from_hex(b'F'), Some(15));
    }

    #[test]
    fn from_hex_invalid_returns_none() {
        assert!(from_hex(b'g').is_none());
        assert!(from_hex(b' ').is_none());
    }

    #[test]
    fn detect_service_via_amz_target() {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", "AmazonSSM.GetParameter".parse().unwrap());
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ssm");
        assert_eq!(detected.action, "GetParameter");
    }

    #[test]
    fn detect_service_via_query_action_with_inferred_service() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "AssumeRole".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "AssumeRole");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    // `Action` carried in a form-encoded body rather than the URL.
    #[test]
    fn detect_service_via_form_body() {
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::from("Action=SendEmail&Source=x%40y.com");
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ses");
        assert_eq!(detected.action, "SendEmail");
    }

    #[test]
    fn detect_service_via_sigv2_presigned() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("AWSAccessKeyId".to_string(), "AKID".to_string());
        query.insert("Signature".to_string(), "sig".to_string());
        query.insert("Expires".to_string(), "1234567890".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_via_sigv4_presigned_credential() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/s3/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_unknown_returns_none() {
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::new();
        assert!(detect_service(&headers, &query, &body).is_none());
    }

    #[test]
    fn normalize_service_name_aliases_apigatewayv2_to_apigateway() {
        assert_eq!(normalize_service_name("apigatewayv2"), "apigateway");
    }

    #[test]
    fn normalize_service_name_aliases_bedrock_runtime_to_bedrock() {
        assert_eq!(normalize_service_name("bedrock-runtime"), "bedrock");
    }

    #[test]
    fn normalize_service_name_passes_through_unaliased_services() {
        assert_eq!(normalize_service_name("bedrock"), "bedrock");
        assert_eq!(normalize_service_name("s3"), "s3");
        assert_eq!(normalize_service_name("lambda"), "lambda");
        assert_eq!(normalize_service_name(""), "");
        assert_eq!(
            normalize_service_name("unknown-future-service"),
            "unknown-future-service"
        );
    }

    // The SigV4 signing name `bedrock-runtime` is folded into `bedrock`,
    // which is a REST/JSON service.
    #[test]
    fn detect_service_via_authorization_header_normalizes_bedrock_runtime() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 \
             Credential=AKID/20240101/us-east-1/bedrock-runtime/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn detect_service_via_sigv4_presigned_credential_normalizes_bedrock_runtime() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/bedrock-runtime/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn parse_routing_host_localstack_basic() {
        let h = parse_routing_host("sqs.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_localstack_with_port() {
        let h = parse_routing_host("lambda.eu-west-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "eu-west-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_case_insensitive() {
        let h = parse_routing_host("SQS.US-EAST-1.LOCALHOST.LOCALSTACK.CLOUD:4566").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");

        let h = parse_routing_host("LAMBDA.US-EAST-1.AMAZONAWS.COM").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "us-east-1");
    }

    #[test]
    fn parse_routing_host_localstack_s3_virtual_hosted() {
        let h =
            parse_routing_host("my-bucket.s3.us-east-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_localstack_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_service_region() {
        let h = parse_routing_host("sqs.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("dynamodb.eu-west-2.amazonaws.com:443").unwrap();
        assert_eq!(h.service, "dynamodb");
        assert_eq!(h.region, "eu-west-2");
    }

    #[test]
    fn parse_routing_host_aws_s3_path_style_modern() {
        let h = parse_routing_host("s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_aws_s3_virtual_hosted_modern() {
        let h = parse_routing_host("my-bucket.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    // Legacy global endpoint: no region label, so the region defaults to
    // us-east-1.
    #[test]
    fn parse_routing_host_aws_s3_legacy_global() {
        let h = parse_routing_host("s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_legacy_global_dotted_bucket() {
        let h = parse_routing_host("a.b.c.s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    // Legacy dash-separated form `s3-<region>`.
    #[test]
    fn parse_routing_host_aws_s3_dash_separated() {
        let h = parse_routing_host("s3-us-west-2.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3-us-west-2.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_rejects_plain_localhost() {
        assert!(parse_routing_host("localhost:4566").is_none());
        assert!(parse_routing_host("127.0.0.1:4566").is_none());
    }

    #[test]
    fn parse_routing_host_rejects_unknown_suffix() {
        assert!(parse_routing_host("sqs.us-east-1.example.com").is_none());
        assert!(parse_routing_host("s3.us-east-1.aws").is_none());
    }

    #[test]
    fn parse_routing_host_empty_and_malformed_rejected() {
        assert!(parse_routing_host("").is_none());
        assert!(parse_routing_host(".localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("..localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("sqs.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("foo.bar.baz.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host(".amazonaws.com").is_none());
        assert!(parse_routing_host("amazonaws.com").is_none());
    }

    // Regression guard: a bare label must not cause out-of-range slicing.
    #[test]
    fn parse_routing_host_bare_s3_accesspoint_does_not_panic() {
        assert!(parse_routing_host("s3-accesspoint").is_none());
    }

    #[test]
    fn detect_service_via_host_for_rest_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_via_host_for_rest_json_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "lambda");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    // An action that cannot be inferred by name (ListQueues) falls back to the
    // service named in the Host header.
    #[test]
    fn detect_service_via_host_plus_query_action() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "sqs.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "ListQueues".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sqs");
        assert_eq!(detected.action, "ListQueues");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    // The SigV4 credential scope (s3) takes precedence over the Host header
    // (lambda).
    #[test]
    fn detect_service_sigv4_wins_over_host() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 Credential=AKID/20240101/us-east-1/s3/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_host_for_virtual_hosted_s3() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "my-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }
}
1354}