1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Wire-protocol family assigned to an incoming request by the detection functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AwsProtocol {
    /// Parameter-style request identified by an `Action=...` key in the query string or
    /// in a form-encoded body.
    Query,
    /// JSON request identified by the `X-Amz-Target` header (`<Prefix>.<Action>`).
    Json,
    /// REST request for a service listed in `REST_XML_SERVICES` (or a SigV2 presigned S3 URL).
    Rest,
    /// REST request for a service listed in `REST_JSON_SERVICES`.
    RestJson,
}
21
/// Service names whose REST requests are classified as [`AwsProtocol::Rest`] (REST-XML).
const REST_XML_SERVICES: &[&str] = &["s3", "cloudfront", "route53"];

/// Service names whose REST requests are classified as [`AwsProtocol::RestJson`].
///
/// Names coming from a SigV4 credential scope are passed through `normalize_service_name`
/// before being looked up here, so aliases such as `bedrock-runtime` resolve to `bedrock`.
const REST_JSON_SERVICES: &[&str] = &[
    "lambda", "ses", "apigateway",
    "bedrock", "bedrock-agent", "bedrock-agent-runtime",
    "scheduler", "batch", "pipes",
    "rds-data", "dsql", "resource-groups",
    "eks", "account",
];
42
43#[derive(Debug, Clone)]
45pub struct DetectedRequest {
46 pub service: String,
47 pub action: String,
48 pub protocol: AwsProtocol,
49}
50
51pub fn detect_service_headers_only(
58 headers: &HeaderMap,
59 query_params: &HashMap<String, String>,
60) -> Option<DetectedRequest> {
61 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
63 return parse_amz_target(target);
64 }
65 if let Some(action) = query_params.get("Action") {
66 let service = extract_service_from_auth(headers)
67 .or_else(|| infer_service_from_action(action))
68 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
69 if let Some(service) = service {
70 return Some(DetectedRequest {
71 service,
72 action: action.clone(),
73 protocol: AwsProtocol::Query,
74 });
75 }
76 }
77 if let Some(service) = extract_service_from_auth(headers) {
78 if let Some(protocol) = rest_protocol_for(&service) {
79 return Some(DetectedRequest {
80 service,
81 action: String::new(),
82 protocol,
83 });
84 }
85 }
86 if let Some(credential) = query_params.get("X-Amz-Credential") {
87 let parts: Vec<&str> = credential.split('/').collect();
88 if parts.len() >= 4 {
89 let service = normalize_service_name(parts[3]).to_string();
90 if let Some(protocol) = rest_protocol_for(&service) {
91 return Some(DetectedRequest {
92 service,
93 action: String::new(),
94 protocol,
95 });
96 }
97 }
98 }
99 if query_params.contains_key("AWSAccessKeyId")
100 && query_params.contains_key("Signature")
101 && query_params.contains_key("Expires")
102 {
103 return Some(DetectedRequest {
104 service: "s3".to_string(),
105 action: String::new(),
106 protocol: AwsProtocol::Rest,
107 });
108 }
109 if let Some(host_info) = parse_routing_host_from_headers(headers) {
110 if let Some(protocol) = rest_protocol_for(&host_info.service) {
111 return Some(DetectedRequest {
112 service: host_info.service,
113 action: String::new(),
114 protocol,
115 });
116 }
117 }
118 None
119}
120
121pub fn detect_service(
123 headers: &HeaderMap,
124 query_params: &HashMap<String, String>,
125 body: &Bytes,
126) -> Option<DetectedRequest> {
127 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
129 return parse_amz_target(target);
130 }
131
132 if let Some(action) = query_params.get("Action") {
134 let service = extract_service_from_auth(headers)
135 .or_else(|| infer_service_from_action(action))
136 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
137 if let Some(service) = service {
138 return Some(DetectedRequest {
139 service,
140 action: action.clone(),
141 protocol: AwsProtocol::Query,
142 });
143 }
144 }
145
146 {
148 let form_params = decode_form_urlencoded(body);
149
150 if let Some(action) = form_params.get("Action") {
151 let service = extract_service_from_auth(headers)
152 .or_else(|| infer_service_from_action(action))
153 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
154 if let Some(service) = service {
155 return Some(DetectedRequest {
156 service,
157 action: action.clone(),
158 protocol: AwsProtocol::Query,
159 });
160 }
161 }
162 }
163
164 if let Some(service) = extract_service_from_auth(headers) {
166 if let Some(protocol) = rest_protocol_for(&service) {
167 return Some(DetectedRequest {
168 service,
169 action: String::new(), protocol,
171 });
172 }
173 }
174
175 if let Some(credential) = query_params.get("X-Amz-Credential") {
177 let parts: Vec<&str> = credential.split('/').collect();
179 if parts.len() >= 4 {
180 let service = normalize_service_name(parts[3]).to_string();
181 if let Some(protocol) = rest_protocol_for(&service) {
182 return Some(DetectedRequest {
183 service,
184 action: String::new(),
185 protocol,
186 });
187 }
188 }
189 }
190
191 if query_params.contains_key("AWSAccessKeyId")
195 && query_params.contains_key("Signature")
196 && query_params.contains_key("Expires")
197 {
198 return Some(DetectedRequest {
199 service: "s3".to_string(),
200 action: String::new(),
201 protocol: AwsProtocol::Rest,
202 });
203 }
204
205 if let Some(host_info) = parse_routing_host_from_headers(headers) {
209 if let Some(protocol) = rest_protocol_for(&host_info.service) {
210 return Some(DetectedRequest {
211 service: host_info.service,
212 action: String::new(),
213 protocol,
214 });
215 }
216 }
217
218 None
219}
220
/// Routing information recovered from a request's `Host` header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoutingHost {
    /// Service label, e.g. `"sqs"`; set to `"s3"` for every S3-specific host shape.
    pub service: String,
    /// Region label; `"us-east-1"` for the legacy global `s3.amazonaws.com` endpoint.
    pub region: String,
    /// Bucket (or access-point) name for virtual-hosted S3 hosts; may contain dots.
    pub bucket: Option<String>,
}
236
237const LOCALSTACK_SUFFIX: &str = ".localhost.localstack.cloud";
238const AWS_SUFFIX: &str = ".amazonaws.com";
239
240pub fn parse_routing_host(host: &str) -> Option<RoutingHost> {
244 let hostname = host.split(':').next()?;
245 if hostname.is_empty() {
246 return None;
247 }
248 let hostname = hostname.to_ascii_lowercase();
249 if let Some(prefix) = hostname.strip_suffix(LOCALSTACK_SUFFIX) {
250 return parse_localstack_prefix(prefix);
251 }
252 if hostname == "amazonaws.com" {
253 return None;
254 }
255 if let Some(prefix) = hostname.strip_suffix(AWS_SUFFIX) {
256 return parse_aws_prefix(prefix);
257 }
258 None
259}
260
261pub fn parse_routing_host_from_headers(headers: &HeaderMap) -> Option<RoutingHost> {
263 let host = headers.get("host")?.to_str().ok()?;
264 parse_routing_host(host)
265}
266
267fn parse_localstack_prefix(prefix: &str) -> Option<RoutingHost> {
268 if prefix.is_empty() {
269 return None;
270 }
271 let labels: Vec<&str> = prefix.split('.').collect();
272 if labels.iter().any(|l| l.is_empty()) {
273 return None;
274 }
275 match labels.len() {
276 2 => Some(RoutingHost {
277 service: labels[0].to_string(),
278 region: labels[1].to_string(),
279 bucket: None,
280 }),
281 n if n >= 3 && labels[n - 2] == "s3" => {
282 let bucket = labels[..n - 2].join(".");
283 Some(RoutingHost {
284 service: "s3".to_string(),
285 region: labels[n - 1].to_string(),
286 bucket: Some(bucket),
287 })
288 }
289 n if n >= 3 && labels[n - 2] == "s3-accesspoint" => {
290 let bucket = labels[..n - 2].join(".");
291 Some(RoutingHost {
292 service: "s3".to_string(),
293 region: labels[n - 1].to_string(),
294 bucket: Some(bucket),
295 })
296 }
297 n if n >= 3 && labels[n - 2] == "s3-control" => Some(RoutingHost {
298 service: "s3".to_string(),
299 region: labels[n - 1].to_string(),
300 bucket: None,
301 }),
302 _ => None,
303 }
304}
305
306fn parse_aws_prefix(prefix: &str) -> Option<RoutingHost> {
318 if prefix.is_empty() {
319 return None;
320 }
321 let labels: Vec<&str> = prefix.split('.').collect();
322 if labels.iter().any(|l| l.is_empty()) {
323 return None;
324 }
325 let last = *labels.last()?;
326
327 if let Some(region) = last.strip_prefix("s3-") {
330 if !region.is_empty() {
331 let bucket = if labels.len() >= 2 {
332 Some(labels[..labels.len() - 1].join("."))
333 } else {
334 None
335 };
336 return Some(RoutingHost {
337 service: "s3".to_string(),
338 region: region.to_string(),
339 bucket,
340 });
341 }
342 }
343
344 if last == "s3" {
348 if labels.len() == 1 {
349 return Some(RoutingHost {
350 service: "s3".to_string(),
351 region: "us-east-1".to_string(),
352 bucket: None,
353 });
354 }
355 return Some(RoutingHost {
356 service: "s3".to_string(),
357 region: "us-east-1".to_string(),
358 bucket: Some(labels[..labels.len() - 1].join(".")),
359 });
360 }
361
362 if last == "s3-accesspoint" {
365 if labels.len() == 2 {
366 return Some(RoutingHost {
367 service: "s3".to_string(),
368 region: labels[0].to_string(),
369 bucket: None,
370 });
371 }
372 if labels.len() >= 3 {
376 let bucket = labels[..labels.len() - 2].join(".");
377 return Some(RoutingHost {
378 service: "s3".to_string(),
379 region: labels[labels.len() - 1].to_string(),
380 bucket: Some(bucket),
381 });
382 }
383 }
384
385 if labels.len() >= 2 && labels[labels.len() - 2] == "s3-control" {
388 return Some(RoutingHost {
389 service: "s3".to_string(),
390 region: last.to_string(),
391 bucket: None,
392 });
393 }
394
395 match labels.len() {
396 2 => Some(RoutingHost {
399 service: labels[0].to_string(),
400 region: labels[1].to_string(),
401 bucket: None,
402 }),
403 n if n >= 3 && labels[n - 2] == "s3" => {
405 let bucket = labels[..n - 2].join(".");
406 Some(RoutingHost {
407 service: "s3".to_string(),
408 region: labels[n - 1].to_string(),
409 bucket: Some(bucket),
410 })
411 }
412 _ => None,
413 }
414}
415
416fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
419 let (prefix, action) = target.rsplit_once('.')?;
420
421 let service = match prefix {
422 "AWSEvents" => "events",
423 "AmazonSSM" => "ssm",
424 "AmazonSQS" => "sqs",
425 "AmazonSNS" => "sns",
426 "DynamoDB_20120810" => "dynamodb",
427 "DynamoDBStreams_20120810" => "dynamodbstreams",
428 "Logs_20140328" => "logs",
429 s if s.starts_with("secretsmanager") => "secretsmanager",
430 s if s.starts_with("TrentService") => "kms",
431 s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
432 s if s.starts_with("AWSCognitoIdentityService") => "cognito-identity",
433 s if s.starts_with("Kinesis_20131202") => "kinesis",
434 s if s.starts_with("AmazonEC2ContainerRegistry_V") => "ecr",
435 s if s.starts_with("AmazonEC2ContainerServiceV") => "ecs",
436 s if s.starts_with("AWSStepFunctions") => "states",
437 s if s.starts_with("AWSOrganizationsV") => "organizations",
438 "CertificateManager" => "acm",
439 "AnyScaleFrontendService" => "application-autoscaling",
440 "AWSWAF_20190729" => "wafv2",
443 "AmazonAthena" => "athena",
444 s if s.starts_with("Firehose_") => "firehose",
445 "AWSGlue" => "glue",
446 "CloudApiService" => "cloudcontrolapi",
447 "ResourceGroupsTaggingAPI_20170126" => "tagging",
448 "AmazonMemoryDB" => "memorydb",
449 "Route53AutoNaming_v20170314" => "servicediscovery",
452 s if s.starts_with("GraniteServiceVersion") => "monitoring",
458 _ => return None,
459 };
460
461 Some(DetectedRequest {
462 service: service.to_string(),
463 action: action.to_string(),
464 protocol: AwsProtocol::Json,
465 })
466}
467
468fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
470 if REST_XML_SERVICES.contains(&service) {
471 Some(AwsProtocol::Rest)
472 } else if REST_JSON_SERVICES.contains(&service) {
473 Some(AwsProtocol::RestJson)
474 } else {
475 None
476 }
477}
478
/// Guesses the owning service of a query-protocol `Action` when no other signal names it.
///
/// Only a fixed set of STS, IAM, SES and SNS action names is known. Matching is exact
/// and case-sensitive, and the sets are disjoint. Anything else yields `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    const SES_ACTIONS: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];
    // Reached by unsigned subscription-confirmation style requests.
    const SNS_ACTIONS: &[&str] = &["ConfirmSubscription", "Unsubscribe"];

    let catalogue: [(&str, &[&str]); 4] = [
        ("sts", STS_ACTIONS),
        ("iam", IAM_ACTIONS),
        ("ses", SES_ACTIONS),
        ("sns", SNS_ACTIONS),
    ];
    catalogue
        .iter()
        .find(|(_, actions)| actions.iter().any(|&known| known == action))
        .map(|&(service, _)| service.to_string())
}
550
551fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
553 let auth = headers.get("authorization")?.to_str().ok()?;
554 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
555 Some(normalize_service_name(&info.service).to_string())
556}
557
/// Maps signing-name aliases onto the service name used for routing.
///
/// `bedrock-runtime` becomes `bedrock` and `apigatewayv2` becomes `apigateway`. Every
/// other name, including the empty string, is returned unchanged.
fn normalize_service_name(service: &str) -> &str {
    const ALIASES: &[(&str, &str)] = &[
        ("bedrock-runtime", "bedrock"),
        ("apigatewayv2", "apigateway"),
    ];
    ALIASES
        .iter()
        .find(|(alias, _)| *alias == service)
        .map_or(service, |&(_, canonical)| canonical)
}
583
584pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
586 decode_form_urlencoded(body)
587}
588
589pub fn flatten_json_to_query(body: &Bytes) -> HashMap<String, String> {
606 let mut out = HashMap::new();
607 let Ok(value) = serde_json::from_slice::<serde_json::Value>(body) else {
608 return out;
609 };
610 if value.is_object() {
611 flatten_json_value("", &value, &mut out);
612 }
613 out
614}
615
616fn flatten_json_value(prefix: &str, value: &serde_json::Value, out: &mut HashMap<String, String>) {
617 match value {
618 serde_json::Value::Object(map) => {
619 for (k, v) in map {
620 let child = if prefix.is_empty() {
621 k.clone()
622 } else {
623 format!("{prefix}.{k}")
624 };
625 flatten_json_value(&child, v, out);
626 }
627 }
628 serde_json::Value::Array(items) => {
629 for (i, v) in items.iter().enumerate() {
630 let child = format!("{prefix}.member.{}", i + 1);
631 flatten_json_value(&child, v, out);
632 }
633 }
634 serde_json::Value::Null => {}
635 serde_json::Value::String(s) => {
636 out.insert(prefix.to_string(), s.clone());
637 }
638 serde_json::Value::Bool(b) => {
639 out.insert(prefix.to_string(), b.to_string());
640 }
641 serde_json::Value::Number(n) => {
642 out.insert(prefix.to_string(), n.to_string());
643 }
644 }
645}
646
/// Parses `application/x-www-form-urlencoded` bytes into a key/value map.
///
/// Pairs are split on `&` and then on the first `=`. A pair without `=` maps to an empty
/// value, empty pairs are skipped, and for repeated keys the last occurrence wins.
/// Input that is not valid UTF-8 (e.g. a binary upload body) is treated as having no
/// parameters at all, so arbitrary payloads are not mistaken for form fields.
fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
    let Ok(text) = std::str::from_utf8(input) else {
        return HashMap::new();
    };
    text.split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (url_decode(key), url_decode(value))
        })
        .collect()
}

/// Percent-decodes one form-encoded component.
///
/// `+` becomes a space and `%XX` becomes the byte `0xXX`. The decoded bytes are then read
/// as UTF-8, so multi-byte escapes such as `%C3%A9` yield `é` (previously each byte was
/// pushed as a Latin-1 char, producing mojibake). Invalid UTF-8 is replaced with U+FFFD.
/// A malformed escape (non-hex digit or truncated input) is dropped together with the up
/// to two bytes following the `%`.
fn url_decode(input: &str) -> String {
    let mut decoded = Vec::with_capacity(input.len());
    let mut bytes = input.bytes();
    while let Some(b) = bytes.next() {
        match b {
            b'+' => decoded.push(b' '),
            b'%' => {
                // Both bytes are always consumed, even if the first is not a hex digit.
                let high = bytes.next().and_then(from_hex);
                let low = bytes.next().and_then(from_hex);
                if let (Some(h), Some(l)) = (high, low) {
                    decoded.push(h << 4 | l);
                }
            }
            _ => decoded.push(b),
        }
    }
    match String::from_utf8(decoded) {
        Ok(s) => s,
        Err(e) => String::from_utf8_lossy(e.as_bytes()).into_owned(),
    }
}

/// Value of one ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None`.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
690
#[cfg(test)]
mod tests {
    //! Unit tests for request classification (`detect_service`), `X-Amz-Target` and
    //! action-table lookups, form/JSON parameter decoding, and `Host`-header routing.

    use super::*;

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_kinesis() {
        let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
        assert_eq!(result.service, "kinesis");
        assert_eq!(result.action, "ListStreams");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }

    #[test]
    fn parse_query_body_empty_returns_empty_map() {
        let body = Bytes::from("");
        let params = parse_query_body(&body);
        assert!(params.is_empty());
    }

    #[test]
    fn parse_query_body_duplicate_keys_last_wins() {
        let body = Bytes::from("key=a&key=b");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "b");
    }

    #[test]
    fn parse_query_body_single_key() {
        let body = Bytes::from("key=value");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "value");
    }

    #[test]
    fn parse_amz_target_ecs() {
        let result = parse_amz_target("AmazonEC2ContainerServiceV20141113.ListClusters").unwrap();
        assert_eq!(result.service, "ecs");
        assert_eq!(result.action, "ListClusters");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_invalid_returns_none() {
        assert!(parse_amz_target("NoDotHere").is_none());
        assert!(parse_amz_target("").is_none());
    }

    #[test]
    fn parse_amz_target_cloudwatch_json() {
        // The CloudWatch JSON target prefix maps to the `monitoring` service name.
        let result = parse_amz_target("GraniteServiceVersion20100801.PutMetricData").unwrap();
        assert_eq!(result.service, "monitoring");
        assert_eq!(result.action, "PutMetricData");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn flatten_json_to_query_nested() {
        // Arrays flatten to `.member.N` (1-based); nested objects flatten to dotted paths.
        let body = Bytes::from(
            serde_json::json!({
                "Namespace": "MyApp",
                "MetricData": [{
                    "MetricName": "Latency",
                    "Value": 12.5,
                    "StatisticValues": {"SampleCount": 3, "Sum": 10},
                    "Dimensions": [{"Name": "Endpoint", "Value": "/api"}]
                }]
            })
            .to_string(),
        );
        let flat = flatten_json_to_query(&body);
        assert_eq!(flat.get("Namespace").unwrap(), "MyApp");
        assert_eq!(
            flat.get("MetricData.member.1.MetricName").unwrap(),
            "Latency"
        );
        assert_eq!(flat.get("MetricData.member.1.Value").unwrap(), "12.5");
        assert_eq!(
            flat.get("MetricData.member.1.StatisticValues.SampleCount")
                .unwrap(),
            "3"
        );
        assert_eq!(
            flat.get("MetricData.member.1.Dimensions.member.1.Name")
                .unwrap(),
            "Endpoint"
        );
        assert_eq!(
            flat.get("MetricData.member.1.Dimensions.member.1.Value")
                .unwrap(),
            "/api"
        );
    }

    #[test]
    fn flatten_json_to_query_non_object_is_empty() {
        assert!(flatten_json_to_query(&Bytes::from_static(b"[]")).is_empty());
        assert!(flatten_json_to_query(&Bytes::from_static(b"not json")).is_empty());
    }

    #[test]
    fn parse_amz_target_various_prefixes() {
        assert_eq!(
            parse_amz_target("AmazonSQS.SendMessage").unwrap().service,
            "sqs"
        );
        assert_eq!(
            parse_amz_target("AmazonSNS.Publish").unwrap().service,
            "sns"
        );
        assert_eq!(
            parse_amz_target("DynamoDB_20120810.GetItem")
                .unwrap()
                .service,
            "dynamodb"
        );
        assert_eq!(
            parse_amz_target("Logs_20140328.PutLogEvents")
                .unwrap()
                .service,
            "logs"
        );
        assert_eq!(
            parse_amz_target("secretsmanager.GetSecretValue")
                .unwrap()
                .service,
            "secretsmanager"
        );
        assert_eq!(
            parse_amz_target("TrentService.Encrypt").unwrap().service,
            "kms"
        );
        assert_eq!(
            parse_amz_target("AWSCognitoIdentityProviderService.InitiateAuth")
                .unwrap()
                .service,
            "cognito-idp"
        );
        assert_eq!(
            parse_amz_target("AWSStepFunctions.StartExecution")
                .unwrap()
                .service,
            "states"
        );
        assert_eq!(
            parse_amz_target("AWSOrganizationsV20161128.CreateOrganization")
                .unwrap()
                .service,
            "organizations"
        );
        assert!(parse_amz_target("UnknownServicePrefix.Action").is_none());
    }

    #[test]
    fn infer_service_from_action_maps_sts() {
        assert_eq!(
            infer_service_from_action("AssumeRole").as_deref(),
            Some("sts")
        );
        assert_eq!(
            infer_service_from_action("GetCallerIdentity").as_deref(),
            Some("sts")
        );
    }

    #[test]
    fn infer_service_from_action_maps_iam() {
        assert_eq!(
            infer_service_from_action("CreateUser").as_deref(),
            Some("iam")
        );
        assert_eq!(
            infer_service_from_action("ListRoles").as_deref(),
            Some("iam")
        );
    }

    #[test]
    fn infer_service_from_action_maps_ses() {
        assert_eq!(
            infer_service_from_action("SendEmail").as_deref(),
            Some("ses")
        );
        assert_eq!(
            infer_service_from_action("ListIdentities").as_deref(),
            Some("ses")
        );
    }

    #[test]
    fn infer_service_from_action_maps_sns_confirmation_flow() {
        assert_eq!(
            infer_service_from_action("ConfirmSubscription").as_deref(),
            Some("sns")
        );
        assert_eq!(
            infer_service_from_action("Unsubscribe").as_deref(),
            Some("sns")
        );
    }

    #[test]
    fn detect_service_routes_unsigned_confirm_subscription_to_sns() {
        // No Authorization header and a plain `localhost` Host, so the service can only
        // come from the static action table.
        let mut headers = HeaderMap::new();
        headers.insert("host", "localhost:4566".parse().unwrap());
        let mut query_params = HashMap::new();
        query_params.insert("Action".to_string(), "ConfirmSubscription".to_string());
        query_params.insert(
            "TopicArn".to_string(),
            "arn:aws:sns:us-east-1:000000000000:t".to_string(),
        );
        query_params.insert("Token".to_string(), "abc123".to_string());

        let detected = detect_service(&headers, &query_params, &Bytes::new())
            .expect("ConfirmSubscription must route to a service");
        assert_eq!(detected.service, "sns");
        assert_eq!(detected.action, "ConfirmSubscription");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn infer_service_from_action_unknown_returns_none() {
        assert!(infer_service_from_action("NotARealAction").is_none());
    }

    #[test]
    fn rest_protocol_for_returns_none_for_non_rest_service() {
        assert!(rest_protocol_for("sqs").is_none());
    }

    #[test]
    fn url_decode_handles_percent_and_plus() {
        assert_eq!(url_decode("hello+world"), "hello world");
        assert_eq!(url_decode("hello%20world"), "hello world");
        assert_eq!(url_decode("100%25"), "100%");
    }

    #[test]
    fn url_decode_ignores_malformed_percent() {
        // An invalid escape is dropped together with the two bytes after the `%`.
        assert_eq!(url_decode("%ZZ"), "");
    }

    #[test]
    fn from_hex_valid_digits() {
        assert_eq!(from_hex(b'0'), Some(0));
        assert_eq!(from_hex(b'9'), Some(9));
        assert_eq!(from_hex(b'a'), Some(10));
        assert_eq!(from_hex(b'F'), Some(15));
    }

    #[test]
    fn from_hex_invalid_returns_none() {
        assert!(from_hex(b'g').is_none());
        assert!(from_hex(b' ').is_none());
    }

    #[test]
    fn detect_service_via_amz_target() {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", "AmazonSSM.GetParameter".parse().unwrap());
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ssm");
        assert_eq!(detected.action, "GetParameter");
    }

    #[test]
    fn detect_service_via_query_action_with_inferred_service() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "AssumeRole".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "AssumeRole");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_via_form_body() {
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::from("Action=SendEmail&Source=x%40y.com");
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ses");
        assert_eq!(detected.action, "SendEmail");
    }

    #[test]
    fn detect_service_via_sigv2_presigned() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("AWSAccessKeyId".to_string(), "AKID".to_string());
        query.insert("Signature".to_string(), "sig".to_string());
        query.insert("Expires".to_string(), "1234567890".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_via_sigv4_presigned_credential() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/s3/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_unknown_returns_none() {
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::new();
        assert!(detect_service(&headers, &query, &body).is_none());
    }

    #[test]
    fn normalize_service_name_aliases_apigatewayv2_to_apigateway() {
        assert_eq!(normalize_service_name("apigatewayv2"), "apigateway");
    }

    #[test]
    fn normalize_service_name_aliases_bedrock_runtime_to_bedrock() {
        assert_eq!(normalize_service_name("bedrock-runtime"), "bedrock");
    }

    #[test]
    fn normalize_service_name_passes_through_unaliased_services() {
        assert_eq!(normalize_service_name("bedrock"), "bedrock");
        assert_eq!(normalize_service_name("s3"), "s3");
        assert_eq!(normalize_service_name("lambda"), "lambda");
        assert_eq!(normalize_service_name(""), "");
        assert_eq!(
            normalize_service_name("unknown-future-service"),
            "unknown-future-service"
        );
    }

    #[test]
    fn detect_service_via_authorization_header_normalizes_bedrock_runtime() {
        // The credential scope names `bedrock-runtime`; detection must report `bedrock`.
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 \
             Credential=AKID/20240101/us-east-1/bedrock-runtime/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn detect_service_via_sigv4_presigned_credential_normalizes_bedrock_runtime() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/bedrock-runtime/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn parse_routing_host_localstack_basic() {
        let h = parse_routing_host("sqs.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_localstack_with_port() {
        let h = parse_routing_host("lambda.eu-west-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "eu-west-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_case_insensitive() {
        let h = parse_routing_host("SQS.US-EAST-1.LOCALHOST.LOCALSTACK.CLOUD:4566").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");

        let h = parse_routing_host("LAMBDA.US-EAST-1.AMAZONAWS.COM").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "us-east-1");
    }

    #[test]
    fn parse_routing_host_localstack_s3_virtual_hosted() {
        let h =
            parse_routing_host("my-bucket.s3.us-east-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_localstack_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_service_region() {
        let h = parse_routing_host("sqs.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("dynamodb.eu-west-2.amazonaws.com:443").unwrap();
        assert_eq!(h.service, "dynamodb");
        assert_eq!(h.region, "eu-west-2");
    }

    #[test]
    fn parse_routing_host_aws_s3_path_style_modern() {
        let h = parse_routing_host("s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_aws_s3_virtual_hosted_modern() {
        let h = parse_routing_host("my-bucket.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_s3_legacy_global() {
        // The legacy global endpoint carries no region label; us-east-1 is assumed.
        let h = parse_routing_host("s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_legacy_global_dotted_bucket() {
        let h = parse_routing_host("a.b.c.s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_s3_dash_separated() {
        // Legacy `s3-<region>` form: the region is embedded in the S3 label itself.
        let h = parse_routing_host("s3-us-west-2.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3-us-west-2.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_rejects_plain_localhost() {
        assert!(parse_routing_host("localhost:4566").is_none());
        assert!(parse_routing_host("127.0.0.1:4566").is_none());
    }

    #[test]
    fn parse_routing_host_rejects_unknown_suffix() {
        assert!(parse_routing_host("sqs.us-east-1.example.com").is_none());
        assert!(parse_routing_host("s3.us-east-1.aws").is_none());
    }

    #[test]
    fn parse_routing_host_empty_and_malformed_rejected() {
        assert!(parse_routing_host("").is_none());
        assert!(parse_routing_host(".localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("..localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("sqs.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("foo.bar.baz.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host(".amazonaws.com").is_none());
        assert!(parse_routing_host("amazonaws.com").is_none());
    }

    #[test]
    fn parse_routing_host_bare_s3_accesspoint_does_not_panic() {
        // NOTE(review): this host has neither the `.amazonaws.com` nor the LocalStack
        // suffix, so `parse_routing_host` rejects it before any S3 access-point parsing
        // runs; the access-point code path is not exercised here.
        assert!(parse_routing_host("s3-accesspoint").is_none());
    }

    #[test]
    fn detect_service_via_host_for_rest_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_via_host_for_rest_json_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "lambda");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn detect_service_via_host_plus_query_action() {
        // `sqs` is not in the action table, so the service must come from the Host header.
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "sqs.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "ListQueues".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sqs");
        assert_eq!(detected.action, "ListQueues");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_sigv4_wins_over_host() {
        // The SigV4 credential scope (s3) takes precedence over the Host header (lambda).
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 Credential=AKID/20240101/us-east-1/s3/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_host_for_virtual_hosted_s3() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "my-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }
}