1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Wire protocol an incoming AWS request was classified as by
/// [`detect_service`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AwsProtocol {
    /// Query protocol: the operation is named by an `Action` parameter carried
    /// in the query string or in a form-encoded body.
    Query,
    /// JSON protocol: the operation is named by the `X-Amz-Target` header,
    /// formatted as `<TargetPrefix>.<Action>`.
    Json,
    /// REST-XML protocol (services in `REST_XML_SERVICES`, plus SigV2
    /// presigned S3 URLs). Detection does not determine the action.
    Rest,
    /// REST-JSON protocol (services in `REST_JSON_SERVICES`). Detection does
    /// not determine the action.
    RestJson,
}
21
/// Services classified as [`AwsProtocol::Rest`] (REST-XML). This list is only
/// consulted after the `X-Amz-Target` and `Action` checks have failed to
/// classify the request.
const REST_XML_SERVICES: &[&str] = &["s3", "cloudfront", "route53"];

/// Services classified as [`AwsProtocol::RestJson`], under the same conditions
/// as `REST_XML_SERVICES`. Entries are compared against names that have been
/// passed through `normalize_service_name` on the SigV4 paths, so aliases such
/// as `bedrock-runtime` appear here under their canonical name (`bedrock`).
const REST_JSON_SERVICES: &[&str] = &[
    "lambda",
    // "ses" is also reachable via the Query protocol: a request that carries
    // an `Action` parameter is classified as Query before this list is used.
    "ses",
    "apigateway",
    "bedrock",
    "bedrock-agent",
    "bedrock-agent-runtime",
    "scheduler",
    "batch",
    "pipes",
    "rds-data",
    "dsql",
    "resource-groups",
    "eks",
    "glacier",
    "backup",
    "ram",
    "es",
    "account",
    "appconfig",
];
55
/// Outcome of request classification: which service should handle the
/// request, the operation name (when the protocol carries one), and the wire
/// protocol.
#[derive(Debug, Clone)]
pub struct DetectedRequest {
    /// Service name, e.g. `"sqs"`, `"s3"`, `"bedrock"`.
    pub service: String,
    /// Operation name for Query and JSON requests. Always empty for REST
    /// requests (presumably resolved later from method/path — confirm with
    /// the router).
    pub action: String,
    /// Wire protocol the request was classified as.
    pub protocol: AwsProtocol,
}
63
64pub fn detect_service_headers_only(
71 headers: &HeaderMap,
72 query_params: &HashMap<String, String>,
73) -> Option<DetectedRequest> {
74 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
76 return parse_amz_target(target);
77 }
78 if let Some(action) = query_params.get("Action") {
79 let service = extract_service_from_auth(headers)
80 .or_else(|| infer_service_from_action(action))
81 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
82 if let Some(service) = service {
83 return Some(DetectedRequest {
84 service,
85 action: action.clone(),
86 protocol: AwsProtocol::Query,
87 });
88 }
89 }
90 if let Some(service) = extract_service_from_auth(headers) {
91 if let Some(protocol) = rest_protocol_for(&service) {
92 return Some(DetectedRequest {
93 service,
94 action: String::new(),
95 protocol,
96 });
97 }
98 }
99 if let Some(credential) = query_params.get("X-Amz-Credential") {
100 let parts: Vec<&str> = credential.split('/').collect();
101 if parts.len() >= 4 {
102 let service = normalize_service_name(parts[3]).to_string();
103 if let Some(protocol) = rest_protocol_for(&service) {
104 return Some(DetectedRequest {
105 service,
106 action: String::new(),
107 protocol,
108 });
109 }
110 }
111 }
112 if query_params.contains_key("AWSAccessKeyId")
113 && query_params.contains_key("Signature")
114 && query_params.contains_key("Expires")
115 {
116 return Some(DetectedRequest {
117 service: "s3".to_string(),
118 action: String::new(),
119 protocol: AwsProtocol::Rest,
120 });
121 }
122 if let Some(host_info) = parse_routing_host_from_headers(headers) {
123 if let Some(protocol) = rest_protocol_for(&host_info.service) {
124 return Some(DetectedRequest {
125 service: host_info.service,
126 action: String::new(),
127 protocol,
128 });
129 }
130 }
131 None
132}
133
134pub fn detect_service(
136 headers: &HeaderMap,
137 query_params: &HashMap<String, String>,
138 body: &Bytes,
139) -> Option<DetectedRequest> {
140 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
142 return parse_amz_target(target);
143 }
144
145 if let Some(action) = query_params.get("Action") {
147 let service = extract_service_from_auth(headers)
148 .or_else(|| infer_service_from_action(action))
149 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
150 if let Some(service) = service {
151 return Some(DetectedRequest {
152 service,
153 action: action.clone(),
154 protocol: AwsProtocol::Query,
155 });
156 }
157 }
158
159 {
161 let form_params = decode_form_urlencoded(body);
162
163 if let Some(action) = form_params.get("Action") {
164 let service = extract_service_from_auth(headers)
165 .or_else(|| infer_service_from_action(action))
166 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
167 if let Some(service) = service {
168 return Some(DetectedRequest {
169 service,
170 action: action.clone(),
171 protocol: AwsProtocol::Query,
172 });
173 }
174 }
175 }
176
177 if let Some(service) = extract_service_from_auth(headers) {
179 if let Some(protocol) = rest_protocol_for(&service) {
180 return Some(DetectedRequest {
181 service,
182 action: String::new(), protocol,
184 });
185 }
186 }
187
188 if let Some(credential) = query_params.get("X-Amz-Credential") {
190 let parts: Vec<&str> = credential.split('/').collect();
192 if parts.len() >= 4 {
193 let service = normalize_service_name(parts[3]).to_string();
194 if let Some(protocol) = rest_protocol_for(&service) {
195 return Some(DetectedRequest {
196 service,
197 action: String::new(),
198 protocol,
199 });
200 }
201 }
202 }
203
204 if query_params.contains_key("AWSAccessKeyId")
208 && query_params.contains_key("Signature")
209 && query_params.contains_key("Expires")
210 {
211 return Some(DetectedRequest {
212 service: "s3".to_string(),
213 action: String::new(),
214 protocol: AwsProtocol::Rest,
215 });
216 }
217
218 if let Some(host_info) = parse_routing_host_from_headers(headers) {
222 if let Some(protocol) = rest_protocol_for(&host_info.service) {
223 return Some(DetectedRequest {
224 service: host_info.service,
225 action: String::new(),
226 protocol,
227 });
228 }
229 }
230
231 None
232}
233
/// Service and region (and, for S3, a bucket) extracted from a request's
/// `Host` header by [`parse_routing_host`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoutingHost {
    /// Service label from the hostname; `"s3"` for every recognised S3 form.
    pub service: String,
    /// Region label from the hostname; `"us-east-1"` for the legacy global
    /// S3 endpoint, which carries no region label.
    pub region: String,
    /// Bucket for virtual-hosted S3 hostnames (may contain dots). For S3
    /// access-point hostnames this holds the access-point label.
    pub bucket: Option<String>,
}
249
/// Hostname suffix of LocalStack-style endpoints, e.g.
/// `sqs.us-east-1.localhost.localstack.cloud`.
const LOCALSTACK_SUFFIX: &str = ".localhost.localstack.cloud";
/// Hostname suffix of AWS endpoints, e.g. `sqs.us-east-1.amazonaws.com`.
const AWS_SUFFIX: &str = ".amazonaws.com";
252
253pub fn parse_routing_host(host: &str) -> Option<RoutingHost> {
257 let hostname = host.split(':').next()?;
258 if hostname.is_empty() {
259 return None;
260 }
261 let hostname = hostname.to_ascii_lowercase();
262 if let Some(prefix) = hostname.strip_suffix(LOCALSTACK_SUFFIX) {
263 return parse_localstack_prefix(prefix);
264 }
265 if hostname == "amazonaws.com" {
266 return None;
267 }
268 if let Some(prefix) = hostname.strip_suffix(AWS_SUFFIX) {
269 return parse_aws_prefix(prefix);
270 }
271 None
272}
273
274pub fn parse_routing_host_from_headers(headers: &HeaderMap) -> Option<RoutingHost> {
276 let host = headers.get("host")?.to_str().ok()?;
277 parse_routing_host(host)
278}
279
280fn parse_localstack_prefix(prefix: &str) -> Option<RoutingHost> {
281 if prefix.is_empty() {
282 return None;
283 }
284 let labels: Vec<&str> = prefix.split('.').collect();
285 if labels.iter().any(|l| l.is_empty()) {
286 return None;
287 }
288 match labels.len() {
289 2 => Some(RoutingHost {
290 service: labels[0].to_string(),
291 region: labels[1].to_string(),
292 bucket: None,
293 }),
294 n if n >= 3 && labels[n - 2] == "s3" => {
295 let bucket = labels[..n - 2].join(".");
296 Some(RoutingHost {
297 service: "s3".to_string(),
298 region: labels[n - 1].to_string(),
299 bucket: Some(bucket),
300 })
301 }
302 n if n >= 3 && labels[n - 2] == "s3-accesspoint" => {
303 let bucket = labels[..n - 2].join(".");
304 Some(RoutingHost {
305 service: "s3".to_string(),
306 region: labels[n - 1].to_string(),
307 bucket: Some(bucket),
308 })
309 }
310 n if n >= 3 && labels[n - 2] == "s3-control" => Some(RoutingHost {
311 service: "s3".to_string(),
312 region: labels[n - 1].to_string(),
313 bucket: None,
314 }),
315 _ => None,
316 }
317}
318
319fn parse_aws_prefix(prefix: &str) -> Option<RoutingHost> {
331 if prefix.is_empty() {
332 return None;
333 }
334 let labels: Vec<&str> = prefix.split('.').collect();
335 if labels.iter().any(|l| l.is_empty()) {
336 return None;
337 }
338 let last = *labels.last()?;
339
340 if let Some(region) = last.strip_prefix("s3-") {
343 if !region.is_empty() {
344 let bucket = if labels.len() >= 2 {
345 Some(labels[..labels.len() - 1].join("."))
346 } else {
347 None
348 };
349 return Some(RoutingHost {
350 service: "s3".to_string(),
351 region: region.to_string(),
352 bucket,
353 });
354 }
355 }
356
357 if last == "s3" {
361 if labels.len() == 1 {
362 return Some(RoutingHost {
363 service: "s3".to_string(),
364 region: "us-east-1".to_string(),
365 bucket: None,
366 });
367 }
368 return Some(RoutingHost {
369 service: "s3".to_string(),
370 region: "us-east-1".to_string(),
371 bucket: Some(labels[..labels.len() - 1].join(".")),
372 });
373 }
374
375 if last == "s3-accesspoint" {
378 if labels.len() == 2 {
379 return Some(RoutingHost {
380 service: "s3".to_string(),
381 region: labels[0].to_string(),
382 bucket: None,
383 });
384 }
385 if labels.len() >= 3 {
389 let bucket = labels[..labels.len() - 2].join(".");
390 return Some(RoutingHost {
391 service: "s3".to_string(),
392 region: labels[labels.len() - 1].to_string(),
393 bucket: Some(bucket),
394 });
395 }
396 }
397
398 if labels.len() >= 2 && labels[labels.len() - 2] == "s3-control" {
401 return Some(RoutingHost {
402 service: "s3".to_string(),
403 region: last.to_string(),
404 bucket: None,
405 });
406 }
407
408 match labels.len() {
409 2 => Some(RoutingHost {
412 service: labels[0].to_string(),
413 region: labels[1].to_string(),
414 bucket: None,
415 }),
416 n if n >= 3 && labels[n - 2] == "s3" => {
418 let bucket = labels[..n - 2].join(".");
419 Some(RoutingHost {
420 service: "s3".to_string(),
421 region: labels[n - 1].to_string(),
422 bucket: Some(bucket),
423 })
424 }
425 _ => None,
426 }
427}
428
429fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
432 let (prefix, action) = target.rsplit_once('.')?;
433
434 let service = match prefix {
435 "AWSEvents" => "events",
436 "AmazonSSM" => "ssm",
437 "AmazonSQS" => "sqs",
438 "AmazonSNS" => "sns",
439 "DynamoDB_20120810" => "dynamodb",
440 "DynamoDBStreams_20120810" => "dynamodbstreams",
441 "Logs_20140328" => "logs",
442 s if s.starts_with("secretsmanager") => "secretsmanager",
443 s if s.starts_with("TrentService") => "kms",
444 s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
445 s if s.starts_with("AWSCognitoIdentityService") => "cognito-identity",
446 s if s.starts_with("Kinesis_20131202") => "kinesis",
447 s if s.starts_with("AmazonEC2ContainerRegistry_V") => "ecr",
448 s if s.starts_with("AmazonEC2ContainerServiceV") => "ecs",
449 s if s.starts_with("AWSStepFunctions") => "states",
450 s if s.starts_with("AWSOrganizationsV") => "organizations",
451 "CertificateManager" => "acm",
452 "AnyScaleFrontendService" => "application-autoscaling",
453 "AWSWAF_20190729" => "wafv2",
456 "AmazonAthena" => "athena",
457 s if s.starts_with("Firehose_") => "firehose",
458 "AWSGlue" => "glue",
459 "CloudApiService" => "cloudcontrolapi",
460 "ResourceGroupsTaggingAPI_20170126" => "tagging",
461 "AmazonMemoryDB" => "memorydb",
462 "Route53AutoNaming_v20170314" => "servicediscovery",
465 "AmazonDMSv20160101" => "dms",
467 "CloudTrail_20131101" => "cloudtrail",
471 "com.amazonaws.cloudtrail.v20131101.CloudTrail_20131101" => "cloudtrail",
472 "AWSInsightsIndexService" => "ce",
476 "com.amazonaws.costexplorer.v20171025.AWSInsightsIndexService" => "ce",
477 "TransferService" => "transfer",
479 "AWSIdentityStore" => "identitystore",
481 "SWBExternalService" => "sso",
483 "VerifiedPermissions" => "verifiedpermissions",
485 s if s.starts_with("GraniteServiceVersion") => "monitoring",
491 _ => return None,
492 };
493
494 Some(DetectedRequest {
495 service: service.to_string(),
496 action: action.to_string(),
497 protocol: AwsProtocol::Json,
498 })
499}
500
501fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
503 if REST_XML_SERVICES.contains(&service) {
504 Some(AwsProtocol::Rest)
505 } else if REST_JSON_SERVICES.contains(&service) {
506 Some(AwsProtocol::RestJson)
507 } else {
508 None
509 }
510}
511
/// Guesses the owning service of a Query-protocol `Action` when nothing else
/// (credential scope, routing host) identifies it, e.g. for unsigned
/// requests such as an SNS `ConfirmSubscription` link.
///
/// Only a fixed set of STS, IAM, SES and SNS action names is recognised;
/// anything else yields `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    const SES_ACTIONS: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];
    // Actions reachable from unsigned subscription-confirmation/unsubscribe URLs.
    const SNS_ACTIONS: &[&str] = &["ConfirmSubscription", "Unsubscribe"];

    [
        ("sts", STS_ACTIONS),
        ("iam", IAM_ACTIONS),
        ("ses", SES_ACTIONS),
        ("sns", SNS_ACTIONS),
    ]
    .iter()
    .find(|(_, actions)| actions.contains(&action))
    .map(|(service, _)| service.to_string())
}
583
584fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
586 let auth = headers.get("authorization")?.to_str().ok()?;
587 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
588 Some(normalize_service_name(&info.service).to_string())
589}
590
/// Maps a SigV4 signing name onto the canonical service name used for
/// routing, e.g. `bedrock-runtime` → `bedrock`. Names without an alias are
/// returned unchanged.
fn normalize_service_name(service: &str) -> &str {
    const ALIASES: &[(&str, &str)] = &[
        ("bedrock-runtime", "bedrock"),
        ("apigatewayv2", "apigateway"),
        ("opensearch", "es"),
        ("appconfigdata", "appconfig"),
    ];
    ALIASES
        .iter()
        .find(|&&(alias, _)| alias == service)
        .map_or(service, |&(_, canonical)| canonical)
}
630
631pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
633 decode_form_urlencoded(body)
634}
635
636pub fn flatten_json_to_query(body: &Bytes) -> HashMap<String, String> {
653 let mut out = HashMap::new();
654 let Ok(value) = serde_json::from_slice::<serde_json::Value>(body) else {
655 return out;
656 };
657 if value.is_object() {
658 flatten_json_value("", &value, &mut out);
659 }
660 out
661}
662
663fn flatten_json_value(prefix: &str, value: &serde_json::Value, out: &mut HashMap<String, String>) {
664 match value {
665 serde_json::Value::Object(map) => {
666 for (k, v) in map {
667 let child = if prefix.is_empty() {
668 k.clone()
669 } else {
670 format!("{prefix}.{k}")
671 };
672 flatten_json_value(&child, v, out);
673 }
674 }
675 serde_json::Value::Array(items) => {
676 for (i, v) in items.iter().enumerate() {
677 let child = format!("{prefix}.member.{}", i + 1);
678 flatten_json_value(&child, v, out);
679 }
680 }
681 serde_json::Value::Null => {}
682 serde_json::Value::String(s) => {
683 out.insert(prefix.to_string(), s.clone());
684 }
685 serde_json::Value::Bool(b) => {
686 out.insert(prefix.to_string(), b.to_string());
687 }
688 serde_json::Value::Number(n) => {
689 out.insert(prefix.to_string(), n.to_string());
690 }
691 }
692}
693
/// Decodes an `application/x-www-form-urlencoded` payload into a key/value map.
///
/// Empty pairs (from `&&` or a trailing `&`) are skipped. A pair without `=`
/// maps to an empty value, and a repeated key keeps its last value. Input
/// that is not valid UTF-8 (e.g. a binary upload body) yields an empty map
/// rather than an error.
fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
    let Ok(text) = std::str::from_utf8(input) else {
        return HashMap::new();
    };
    text.split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (url_decode(key), url_decode(value))
        })
        .collect()
}

/// Percent-decodes one form component: `+` becomes a space and `%XX` becomes
/// the byte `0xXX`.
///
/// Decoding happens on bytes, and the result is then interpreted as UTF-8.
/// Multi-byte escapes such as `%C3%A9` therefore yield `é`, and raw non-ASCII
/// input passes through intact. Byte sequences that are not valid UTF-8 are
/// replaced with U+FFFD.
///
/// A malformed escape (non-hex digit or truncated input) is dropped together
/// with the up-to-two bytes that follow the `%`.
fn url_decode(input: &str) -> String {
    let mut decoded = Vec::with_capacity(input.len());
    let mut bytes = input.bytes();
    while let Some(byte) = bytes.next() {
        match byte {
            b'+' => decoded.push(b' '),
            b'%' => {
                // Two bytes are consumed even if the first is not a hex digit;
                // only a fully valid escape emits a byte.
                let high = bytes.next().and_then(from_hex);
                let low = bytes.next().and_then(from_hex);
                if let (Some(high), Some(low)) = (high, low) {
                    decoded.push((high << 4) | low);
                }
            }
            // Push raw bytes (not `byte as char`) so multi-byte UTF-8 input
            // is not re-encoded as Latin-1 code points.
            other => decoded.push(other),
        }
    }
    String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}

/// Returns the value (0–15) of an ASCII hex digit in either case, or `None`.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
737
738#[cfg(test)]
739mod tests {
740 use super::*;
741
742 #[test]
743 fn parse_amz_target_events() {
744 let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
745 assert_eq!(result.service, "events");
746 assert_eq!(result.action, "PutEvents");
747 assert_eq!(result.protocol, AwsProtocol::Json);
748 }
749
750 #[test]
751 fn parse_amz_target_ssm() {
752 let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
753 assert_eq!(result.service, "ssm");
754 assert_eq!(result.action, "GetParameter");
755 }
756
757 #[test]
758 fn parse_amz_target_kinesis() {
759 let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
760 assert_eq!(result.service, "kinesis");
761 assert_eq!(result.action, "ListStreams");
762 assert_eq!(result.protocol, AwsProtocol::Json);
763 }
764
765 #[test]
766 fn parse_query_body_basic() {
767 let body = Bytes::from(
768 "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
769 );
770 let params = parse_query_body(&body);
771 assert_eq!(params.get("Action").unwrap(), "SendMessage");
772 assert_eq!(params.get("MessageBody").unwrap(), "hello");
773 }
774
775 #[test]
776 fn parse_query_body_empty_returns_empty_map() {
777 let body = Bytes::from("");
778 let params = parse_query_body(&body);
779 assert!(params.is_empty());
780 }
781
782 #[test]
783 fn parse_query_body_duplicate_keys_last_wins() {
784 let body = Bytes::from("key=a&key=b");
785 let params = parse_query_body(&body);
786 assert_eq!(params.get("key").unwrap(), "b");
787 }
788
789 #[test]
790 fn parse_query_body_single_key() {
791 let body = Bytes::from("key=value");
792 let params = parse_query_body(&body);
793 assert_eq!(params.get("key").unwrap(), "value");
794 }
795
796 #[test]
797 fn parse_amz_target_ecs() {
798 let result = parse_amz_target("AmazonEC2ContainerServiceV20141113.ListClusters").unwrap();
799 assert_eq!(result.service, "ecs");
800 assert_eq!(result.action, "ListClusters");
801 assert_eq!(result.protocol, AwsProtocol::Json);
802 }
803
804 #[test]
805 fn parse_amz_target_invalid_returns_none() {
806 assert!(parse_amz_target("NoDotHere").is_none());
807 assert!(parse_amz_target("").is_none());
808 }
809
810 #[test]
811 fn parse_amz_target_cloudwatch_json() {
812 let result = parse_amz_target("GraniteServiceVersion20100801.PutMetricData").unwrap();
814 assert_eq!(result.service, "monitoring");
815 assert_eq!(result.action, "PutMetricData");
816 assert_eq!(result.protocol, AwsProtocol::Json);
817 }
818
819 #[test]
820 fn flatten_json_to_query_nested() {
821 let body = Bytes::from(
822 serde_json::json!({
823 "Namespace": "MyApp",
824 "MetricData": [{
825 "MetricName": "Latency",
826 "Value": 12.5,
827 "StatisticValues": {"SampleCount": 3, "Sum": 10},
828 "Dimensions": [{"Name": "Endpoint", "Value": "/api"}]
829 }]
830 })
831 .to_string(),
832 );
833 let flat = flatten_json_to_query(&body);
834 assert_eq!(flat.get("Namespace").unwrap(), "MyApp");
835 assert_eq!(
836 flat.get("MetricData.member.1.MetricName").unwrap(),
837 "Latency"
838 );
839 assert_eq!(flat.get("MetricData.member.1.Value").unwrap(), "12.5");
840 assert_eq!(
841 flat.get("MetricData.member.1.StatisticValues.SampleCount")
842 .unwrap(),
843 "3"
844 );
845 assert_eq!(
846 flat.get("MetricData.member.1.Dimensions.member.1.Name")
847 .unwrap(),
848 "Endpoint"
849 );
850 assert_eq!(
851 flat.get("MetricData.member.1.Dimensions.member.1.Value")
852 .unwrap(),
853 "/api"
854 );
855 }
856
857 #[test]
858 fn flatten_json_to_query_non_object_is_empty() {
859 assert!(flatten_json_to_query(&Bytes::from_static(b"[]")).is_empty());
860 assert!(flatten_json_to_query(&Bytes::from_static(b"not json")).is_empty());
861 }
862
863 #[test]
864 fn parse_amz_target_various_prefixes() {
865 assert_eq!(
866 parse_amz_target("AmazonSQS.SendMessage").unwrap().service,
867 "sqs"
868 );
869 assert_eq!(
870 parse_amz_target("AmazonSNS.Publish").unwrap().service,
871 "sns"
872 );
873 assert_eq!(
874 parse_amz_target("DynamoDB_20120810.GetItem")
875 .unwrap()
876 .service,
877 "dynamodb"
878 );
879 assert_eq!(
880 parse_amz_target("Logs_20140328.PutLogEvents")
881 .unwrap()
882 .service,
883 "logs"
884 );
885 assert_eq!(
886 parse_amz_target("secretsmanager.GetSecretValue")
887 .unwrap()
888 .service,
889 "secretsmanager"
890 );
891 assert_eq!(
892 parse_amz_target("TrentService.Encrypt").unwrap().service,
893 "kms"
894 );
895 assert_eq!(
896 parse_amz_target("AWSCognitoIdentityProviderService.InitiateAuth")
897 .unwrap()
898 .service,
899 "cognito-idp"
900 );
901 assert_eq!(
902 parse_amz_target("AWSStepFunctions.StartExecution")
903 .unwrap()
904 .service,
905 "states"
906 );
907 assert_eq!(
908 parse_amz_target("AWSOrganizationsV20161128.CreateOrganization")
909 .unwrap()
910 .service,
911 "organizations"
912 );
913 assert!(parse_amz_target("UnknownServicePrefix.Action").is_none());
914 }
915
916 #[test]
917 fn infer_service_from_action_maps_sts() {
918 assert_eq!(
919 infer_service_from_action("AssumeRole").as_deref(),
920 Some("sts")
921 );
922 assert_eq!(
923 infer_service_from_action("GetCallerIdentity").as_deref(),
924 Some("sts")
925 );
926 }
927
928 #[test]
929 fn infer_service_from_action_maps_iam() {
930 assert_eq!(
931 infer_service_from_action("CreateUser").as_deref(),
932 Some("iam")
933 );
934 assert_eq!(
935 infer_service_from_action("ListRoles").as_deref(),
936 Some("iam")
937 );
938 }
939
940 #[test]
941 fn infer_service_from_action_maps_ses() {
942 assert_eq!(
943 infer_service_from_action("SendEmail").as_deref(),
944 Some("ses")
945 );
946 assert_eq!(
947 infer_service_from_action("ListIdentities").as_deref(),
948 Some("ses")
949 );
950 }
951
952 #[test]
953 fn infer_service_from_action_maps_sns_confirmation_flow() {
954 assert_eq!(
957 infer_service_from_action("ConfirmSubscription").as_deref(),
958 Some("sns")
959 );
960 assert_eq!(
961 infer_service_from_action("Unsubscribe").as_deref(),
962 Some("sns")
963 );
964 }
965
966 #[test]
967 fn detect_service_routes_unsigned_confirm_subscription_to_sns() {
968 let mut headers = HeaderMap::new();
971 headers.insert("host", "localhost:4566".parse().unwrap());
972 let mut query_params = HashMap::new();
973 query_params.insert("Action".to_string(), "ConfirmSubscription".to_string());
974 query_params.insert(
975 "TopicArn".to_string(),
976 "arn:aws:sns:us-east-1:000000000000:t".to_string(),
977 );
978 query_params.insert("Token".to_string(), "abc123".to_string());
979
980 let detected = detect_service(&headers, &query_params, &Bytes::new())
981 .expect("ConfirmSubscription must route to a service");
982 assert_eq!(detected.service, "sns");
983 assert_eq!(detected.action, "ConfirmSubscription");
984 assert_eq!(detected.protocol, AwsProtocol::Query);
985 }
986
987 #[test]
988 fn infer_service_from_action_unknown_returns_none() {
989 assert!(infer_service_from_action("NotARealAction").is_none());
990 }
991
992 #[test]
993 fn rest_protocol_for_returns_none_for_non_rest_service() {
994 assert!(rest_protocol_for("sqs").is_none());
995 }
996
997 #[test]
998 fn url_decode_handles_percent_and_plus() {
999 assert_eq!(url_decode("hello+world"), "hello world");
1000 assert_eq!(url_decode("hello%20world"), "hello world");
1001 assert_eq!(url_decode("100%25"), "100%");
1002 }
1003
1004 #[test]
1005 fn url_decode_ignores_malformed_percent() {
1006 assert_eq!(url_decode("%ZZ"), "");
1007 }
1008
1009 #[test]
1010 fn from_hex_valid_digits() {
1011 assert_eq!(from_hex(b'0'), Some(0));
1012 assert_eq!(from_hex(b'9'), Some(9));
1013 assert_eq!(from_hex(b'a'), Some(10));
1014 assert_eq!(from_hex(b'F'), Some(15));
1015 }
1016
1017 #[test]
1018 fn from_hex_invalid_returns_none() {
1019 assert!(from_hex(b'g').is_none());
1020 assert!(from_hex(b' ').is_none());
1021 }
1022
1023 #[test]
1024 fn detect_service_via_amz_target() {
1025 let mut headers = HeaderMap::new();
1026 headers.insert("x-amz-target", "AmazonSSM.GetParameter".parse().unwrap());
1027 let query = HashMap::new();
1028 let body = Bytes::new();
1029 let detected = detect_service(&headers, &query, &body).unwrap();
1030 assert_eq!(detected.service, "ssm");
1031 assert_eq!(detected.action, "GetParameter");
1032 }
1033
1034 #[test]
1035 fn detect_service_via_query_action_with_inferred_service() {
1036 let headers = HeaderMap::new();
1037 let mut query = HashMap::new();
1038 query.insert("Action".to_string(), "AssumeRole".to_string());
1039 let body = Bytes::new();
1040 let detected = detect_service(&headers, &query, &body).unwrap();
1041 assert_eq!(detected.service, "sts");
1042 assert_eq!(detected.action, "AssumeRole");
1043 assert_eq!(detected.protocol, AwsProtocol::Query);
1044 }
1045
1046 #[test]
1047 fn detect_service_via_form_body() {
1048 let headers = HeaderMap::new();
1049 let query = HashMap::new();
1050 let body = Bytes::from("Action=SendEmail&Source=x%40y.com");
1051 let detected = detect_service(&headers, &query, &body).unwrap();
1052 assert_eq!(detected.service, "ses");
1053 assert_eq!(detected.action, "SendEmail");
1054 }
1055
1056 #[test]
1057 fn detect_service_via_sigv2_presigned() {
1058 let headers = HeaderMap::new();
1059 let mut query = HashMap::new();
1060 query.insert("AWSAccessKeyId".to_string(), "AKID".to_string());
1061 query.insert("Signature".to_string(), "sig".to_string());
1062 query.insert("Expires".to_string(), "1234567890".to_string());
1063 let body = Bytes::new();
1064 let detected = detect_service(&headers, &query, &body).unwrap();
1065 assert_eq!(detected.service, "s3");
1066 assert_eq!(detected.protocol, AwsProtocol::Rest);
1067 }
1068
1069 #[test]
1070 fn detect_service_via_sigv4_presigned_credential() {
1071 let headers = HeaderMap::new();
1072 let mut query = HashMap::new();
1073 query.insert(
1074 "X-Amz-Credential".to_string(),
1075 "AKID/20240101/us-east-1/s3/aws4_request".to_string(),
1076 );
1077 let body = Bytes::new();
1078 let detected = detect_service(&headers, &query, &body).unwrap();
1079 assert_eq!(detected.service, "s3");
1080 assert_eq!(detected.protocol, AwsProtocol::Rest);
1081 }
1082
1083 #[test]
1084 fn detect_service_unknown_returns_none() {
1085 let headers = HeaderMap::new();
1086 let query = HashMap::new();
1087 let body = Bytes::new();
1088 assert!(detect_service(&headers, &query, &body).is_none());
1089 }
1090
1091 #[test]
1092 fn normalize_service_name_aliases_apigatewayv2_to_apigateway() {
1093 assert_eq!(normalize_service_name("apigatewayv2"), "apigateway");
1098 }
1099
1100 #[test]
1101 fn normalize_service_name_aliases_bedrock_runtime_to_bedrock() {
1102 assert_eq!(normalize_service_name("bedrock-runtime"), "bedrock");
1107 }
1108
1109 #[test]
1110 fn normalize_service_name_passes_through_unaliased_services() {
1111 assert_eq!(normalize_service_name("bedrock"), "bedrock");
1115 assert_eq!(normalize_service_name("s3"), "s3");
1116 assert_eq!(normalize_service_name("lambda"), "lambda");
1117 assert_eq!(normalize_service_name(""), "");
1118 assert_eq!(
1119 normalize_service_name("unknown-future-service"),
1120 "unknown-future-service"
1121 );
1122 }
1123
1124 #[test]
1125 fn detect_service_via_authorization_header_normalizes_bedrock_runtime() {
1126 let mut headers = HeaderMap::new();
1131 headers.insert(
1132 "authorization",
1133 "AWS4-HMAC-SHA256 \
1134 Credential=AKID/20240101/us-east-1/bedrock-runtime/aws4_request, \
1135 SignedHeaders=host, Signature=abc"
1136 .parse()
1137 .unwrap(),
1138 );
1139 let query = HashMap::new();
1140 let body = Bytes::new();
1141 let detected = detect_service(&headers, &query, &body).unwrap();
1142 assert_eq!(detected.service, "bedrock");
1143 assert_eq!(detected.protocol, AwsProtocol::RestJson);
1144 }
1145
1146 #[test]
1147 fn detect_service_via_sigv4_presigned_credential_normalizes_bedrock_runtime() {
1148 let headers = HeaderMap::new();
1152 let mut query = HashMap::new();
1153 query.insert(
1154 "X-Amz-Credential".to_string(),
1155 "AKID/20240101/us-east-1/bedrock-runtime/aws4_request".to_string(),
1156 );
1157 let body = Bytes::new();
1158 let detected = detect_service(&headers, &query, &body).unwrap();
1159 assert_eq!(detected.service, "bedrock");
1160 assert_eq!(detected.protocol, AwsProtocol::RestJson);
1161 }
1162
1163 #[test]
1164 fn parse_routing_host_localstack_basic() {
1165 let h = parse_routing_host("sqs.us-east-1.localhost.localstack.cloud").unwrap();
1166 assert_eq!(h.service, "sqs");
1167 assert_eq!(h.region, "us-east-1");
1168 assert!(h.bucket.is_none());
1169 }
1170
1171 #[test]
1172 fn parse_routing_host_localstack_with_port() {
1173 let h = parse_routing_host("lambda.eu-west-1.localhost.localstack.cloud:4566").unwrap();
1174 assert_eq!(h.service, "lambda");
1175 assert_eq!(h.region, "eu-west-1");
1176 assert!(h.bucket.is_none());
1177 }
1178
1179 #[test]
1180 fn parse_routing_host_case_insensitive() {
1181 let h = parse_routing_host("SQS.US-EAST-1.LOCALHOST.LOCALSTACK.CLOUD:4566").unwrap();
1182 assert_eq!(h.service, "sqs");
1183 assert_eq!(h.region, "us-east-1");
1184
1185 let h = parse_routing_host("LAMBDA.US-EAST-1.AMAZONAWS.COM").unwrap();
1186 assert_eq!(h.service, "lambda");
1187 assert_eq!(h.region, "us-east-1");
1188 }
1189
1190 #[test]
1191 fn parse_routing_host_localstack_s3_virtual_hosted() {
1192 let h =
1193 parse_routing_host("my-bucket.s3.us-east-1.localhost.localstack.cloud:4566").unwrap();
1194 assert_eq!(h.service, "s3");
1195 assert_eq!(h.region, "us-east-1");
1196 assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
1197 }
1198
1199 #[test]
1200 fn parse_routing_host_localstack_s3_vhost_bucket_with_dots() {
1201 let h = parse_routing_host("a.b.c.s3.us-east-1.localhost.localstack.cloud").unwrap();
1202 assert_eq!(h.service, "s3");
1203 assert_eq!(h.region, "us-east-1");
1204 assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
1205 }
1206
1207 #[test]
1208 fn parse_routing_host_aws_service_region() {
1209 let h = parse_routing_host("sqs.us-east-1.amazonaws.com").unwrap();
1210 assert_eq!(h.service, "sqs");
1211 assert_eq!(h.region, "us-east-1");
1212 assert!(h.bucket.is_none());
1213
1214 let h = parse_routing_host("dynamodb.eu-west-2.amazonaws.com:443").unwrap();
1215 assert_eq!(h.service, "dynamodb");
1216 assert_eq!(h.region, "eu-west-2");
1217 }
1218
1219 #[test]
1220 fn parse_routing_host_aws_s3_path_style_modern() {
1221 let h = parse_routing_host("s3.us-east-1.amazonaws.com").unwrap();
1222 assert_eq!(h.service, "s3");
1223 assert_eq!(h.region, "us-east-1");
1224 assert!(h.bucket.is_none());
1225 }
1226
1227 #[test]
1228 fn parse_routing_host_aws_s3_virtual_hosted_modern() {
1229 let h = parse_routing_host("my-bucket.s3.us-east-1.amazonaws.com").unwrap();
1230 assert_eq!(h.service, "s3");
1231 assert_eq!(h.region, "us-east-1");
1232 assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
1233 }
1234
1235 #[test]
1236 fn parse_routing_host_aws_s3_vhost_bucket_with_dots() {
1237 let h = parse_routing_host("a.b.c.s3.us-east-1.amazonaws.com").unwrap();
1238 assert_eq!(h.service, "s3");
1239 assert_eq!(h.region, "us-east-1");
1240 assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
1241 }
1242
1243 #[test]
1244 fn parse_routing_host_aws_s3_legacy_global() {
1245 let h = parse_routing_host("s3.amazonaws.com").unwrap();
1248 assert_eq!(h.service, "s3");
1249 assert_eq!(h.region, "us-east-1");
1250 assert!(h.bucket.is_none());
1251
1252 let h = parse_routing_host("my-bucket.s3.amazonaws.com").unwrap();
1253 assert_eq!(h.service, "s3");
1254 assert_eq!(h.region, "us-east-1");
1255 assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
1256 }
1257
1258 #[test]
1259 fn parse_routing_host_aws_s3_legacy_global_dotted_bucket() {
1260 let h = parse_routing_host("a.b.c.s3.amazonaws.com").unwrap();
1263 assert_eq!(h.service, "s3");
1264 assert_eq!(h.region, "us-east-1");
1265 assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
1266 }
1267
1268 #[test]
1269 fn parse_routing_host_aws_s3_dash_separated() {
1270 let h = parse_routing_host("s3-us-west-2.amazonaws.com").unwrap();
1272 assert_eq!(h.service, "s3");
1273 assert_eq!(h.region, "us-west-2");
1274 assert!(h.bucket.is_none());
1275
1276 let h = parse_routing_host("my-bucket.s3-us-west-2.amazonaws.com").unwrap();
1277 assert_eq!(h.service, "s3");
1278 assert_eq!(h.region, "us-west-2");
1279 assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
1280 }
1281
1282 #[test]
1283 fn parse_routing_host_rejects_plain_localhost() {
1284 assert!(parse_routing_host("localhost:4566").is_none());
1285 assert!(parse_routing_host("127.0.0.1:4566").is_none());
1286 }
1287
1288 #[test]
1289 fn parse_routing_host_rejects_unknown_suffix() {
1290 assert!(parse_routing_host("sqs.us-east-1.example.com").is_none());
1291 assert!(parse_routing_host("s3.us-east-1.aws").is_none());
1292 }
1293
1294 #[test]
1295 fn parse_routing_host_empty_and_malformed_rejected() {
1296 assert!(parse_routing_host("").is_none());
1297 assert!(parse_routing_host(".localhost.localstack.cloud").is_none());
1298 assert!(parse_routing_host("..localhost.localstack.cloud").is_none());
1299 assert!(parse_routing_host("sqs.localhost.localstack.cloud").is_none());
1300 assert!(parse_routing_host("foo.bar.baz.localhost.localstack.cloud").is_none());
1301 assert!(parse_routing_host(".amazonaws.com").is_none());
1302 assert!(parse_routing_host("amazonaws.com").is_none());
1303 }
1304
1305 #[test]
1306 fn parse_routing_host_bare_s3_accesspoint_does_not_panic() {
1307 assert!(parse_routing_host("s3-accesspoint").is_none());
1311 }
1312
1313 #[test]
1314 fn detect_service_via_host_for_rest_service() {
1315 let mut headers = HeaderMap::new();
1316 headers.insert(
1317 "host",
1318 "s3.us-east-1.localhost.localstack.cloud:4566"
1319 .parse()
1320 .unwrap(),
1321 );
1322 let query = HashMap::new();
1323 let body = Bytes::new();
1324 let detected = detect_service(&headers, &query, &body).unwrap();
1325 assert_eq!(detected.service, "s3");
1326 assert_eq!(detected.protocol, AwsProtocol::Rest);
1327 }
1328
1329 #[test]
1330 fn detect_service_via_host_for_rest_json_service() {
1331 let mut headers = HeaderMap::new();
1332 headers.insert(
1333 "host",
1334 "lambda.us-east-1.localhost.localstack.cloud:4566"
1335 .parse()
1336 .unwrap(),
1337 );
1338 let query = HashMap::new();
1339 let body = Bytes::new();
1340 let detected = detect_service(&headers, &query, &body).unwrap();
1341 assert_eq!(detected.service, "lambda");
1342 assert_eq!(detected.protocol, AwsProtocol::RestJson);
1343 }
1344
1345 #[test]
1346 fn detect_service_via_host_plus_query_action() {
1347 let mut headers = HeaderMap::new();
1348 headers.insert(
1349 "host",
1350 "sqs.us-east-1.localhost.localstack.cloud:4566"
1351 .parse()
1352 .unwrap(),
1353 );
1354 let mut query = HashMap::new();
1355 query.insert("Action".to_string(), "ListQueues".to_string());
1356 let body = Bytes::new();
1357 let detected = detect_service(&headers, &query, &body).unwrap();
1358 assert_eq!(detected.service, "sqs");
1359 assert_eq!(detected.action, "ListQueues");
1360 assert_eq!(detected.protocol, AwsProtocol::Query);
1361 }
1362
1363 #[test]
1364 fn detect_service_sigv4_wins_over_host() {
1365 let mut headers = HeaderMap::new();
1366 headers.insert(
1367 "authorization",
1368 "AWS4-HMAC-SHA256 Credential=AKID/20240101/us-east-1/s3/aws4_request, \
1369 SignedHeaders=host, Signature=abc"
1370 .parse()
1371 .unwrap(),
1372 );
1373 headers.insert(
1374 "host",
1375 "lambda.us-east-1.localhost.localstack.cloud:4566"
1376 .parse()
1377 .unwrap(),
1378 );
1379 let query = HashMap::new();
1380 let body = Bytes::new();
1381 let detected = detect_service(&headers, &query, &body).unwrap();
1382 assert_eq!(detected.service, "s3");
1384 assert_eq!(detected.protocol, AwsProtocol::Rest);
1385 }
1386
1387 #[test]
1388 fn detect_service_host_for_virtual_hosted_s3() {
1389 let mut headers = HeaderMap::new();
1390 headers.insert(
1391 "host",
1392 "my-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
1393 .parse()
1394 .unwrap(),
1395 );
1396 let query = HashMap::new();
1397 let body = Bytes::new();
1398 let detected = detect_service(&headers, &query, &body).unwrap();
1399 assert_eq!(detected.service, "s3");
1400 assert_eq!(detected.protocol, AwsProtocol::Rest);
1401 }
1402}