1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Wire protocol family of an incoming AWS request, as inferred by the
/// detection functions in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AwsProtocol {
    /// Query protocol: the operation is named by an `Action` parameter carried
    /// in the URL query string or a form-encoded body.
    Query,
    /// JSON protocol: the operation is named by the `X-Amz-Target` header.
    Json,
    /// REST protocol with XML payloads (services listed in `REST_XML_SERVICES`).
    Rest,
    /// REST protocol with JSON payloads (services listed in `REST_JSON_SERVICES`).
    RestJson,
}
21
/// Service names (after signing-name normalisation) that use the REST
/// protocol with XML payloads.
const REST_XML_SERVICES: &[&str] = &[
    "s3",
    "cloudfront",
    "route53",
];

/// Service names (after signing-name normalisation) that use the REST
/// protocol with JSON payloads.
///
/// Requests carrying an `Action` parameter are classified as Query before
/// this list is consulted, so a name appearing here does not prevent Query
/// detection for that service.
const REST_JSON_SERVICES: &[&str] = &[
    "lambda", "ses", "apigateway", "bedrock", "bedrock-agent", "bedrock-agent-runtime",
    "scheduler", "batch", "pipes", "rds-data", "dsql", "resource-groups",
];
40
41#[derive(Debug, Clone)]
43pub struct DetectedRequest {
44 pub service: String,
45 pub action: String,
46 pub protocol: AwsProtocol,
47}
48
49pub fn detect_service_headers_only(
56 headers: &HeaderMap,
57 query_params: &HashMap<String, String>,
58) -> Option<DetectedRequest> {
59 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
61 return parse_amz_target(target);
62 }
63 if let Some(action) = query_params.get("Action") {
64 let service = extract_service_from_auth(headers)
65 .or_else(|| infer_service_from_action(action))
66 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
67 if let Some(service) = service {
68 return Some(DetectedRequest {
69 service,
70 action: action.clone(),
71 protocol: AwsProtocol::Query,
72 });
73 }
74 }
75 if let Some(service) = extract_service_from_auth(headers) {
76 if let Some(protocol) = rest_protocol_for(&service) {
77 return Some(DetectedRequest {
78 service,
79 action: String::new(),
80 protocol,
81 });
82 }
83 }
84 if let Some(credential) = query_params.get("X-Amz-Credential") {
85 let parts: Vec<&str> = credential.split('/').collect();
86 if parts.len() >= 4 {
87 let service = normalize_service_name(parts[3]).to_string();
88 if let Some(protocol) = rest_protocol_for(&service) {
89 return Some(DetectedRequest {
90 service,
91 action: String::new(),
92 protocol,
93 });
94 }
95 }
96 }
97 if query_params.contains_key("AWSAccessKeyId")
98 && query_params.contains_key("Signature")
99 && query_params.contains_key("Expires")
100 {
101 return Some(DetectedRequest {
102 service: "s3".to_string(),
103 action: String::new(),
104 protocol: AwsProtocol::Rest,
105 });
106 }
107 if let Some(host_info) = parse_routing_host_from_headers(headers) {
108 if let Some(protocol) = rest_protocol_for(&host_info.service) {
109 return Some(DetectedRequest {
110 service: host_info.service,
111 action: String::new(),
112 protocol,
113 });
114 }
115 }
116 None
117}
118
119pub fn detect_service(
121 headers: &HeaderMap,
122 query_params: &HashMap<String, String>,
123 body: &Bytes,
124) -> Option<DetectedRequest> {
125 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
127 return parse_amz_target(target);
128 }
129
130 if let Some(action) = query_params.get("Action") {
132 let service = extract_service_from_auth(headers)
133 .or_else(|| infer_service_from_action(action))
134 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
135 if let Some(service) = service {
136 return Some(DetectedRequest {
137 service,
138 action: action.clone(),
139 protocol: AwsProtocol::Query,
140 });
141 }
142 }
143
144 {
146 let form_params = decode_form_urlencoded(body);
147
148 if let Some(action) = form_params.get("Action") {
149 let service = extract_service_from_auth(headers)
150 .or_else(|| infer_service_from_action(action))
151 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
152 if let Some(service) = service {
153 return Some(DetectedRequest {
154 service,
155 action: action.clone(),
156 protocol: AwsProtocol::Query,
157 });
158 }
159 }
160 }
161
162 if let Some(service) = extract_service_from_auth(headers) {
164 if let Some(protocol) = rest_protocol_for(&service) {
165 return Some(DetectedRequest {
166 service,
167 action: String::new(), protocol,
169 });
170 }
171 }
172
173 if let Some(credential) = query_params.get("X-Amz-Credential") {
175 let parts: Vec<&str> = credential.split('/').collect();
177 if parts.len() >= 4 {
178 let service = normalize_service_name(parts[3]).to_string();
179 if let Some(protocol) = rest_protocol_for(&service) {
180 return Some(DetectedRequest {
181 service,
182 action: String::new(),
183 protocol,
184 });
185 }
186 }
187 }
188
189 if query_params.contains_key("AWSAccessKeyId")
193 && query_params.contains_key("Signature")
194 && query_params.contains_key("Expires")
195 {
196 return Some(DetectedRequest {
197 service: "s3".to_string(),
198 action: String::new(),
199 protocol: AwsProtocol::Rest,
200 });
201 }
202
203 if let Some(host_info) = parse_routing_host_from_headers(headers) {
207 if let Some(protocol) = rest_protocol_for(&host_info.service) {
208 return Some(DetectedRequest {
209 service: host_info.service,
210 action: String::new(),
211 protocol,
212 });
213 }
214 }
215
216 None
217}
218
/// Routing information recovered from an AWS-style `Host` header.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RoutingHost {
    /// Service label from the hostname (S3 variants are reported as `"s3"`).
    pub service: String,
    /// Region label from the hostname.
    pub region: String,
    /// Bucket (or access-point) name for virtual-hosted S3 hostnames;
    /// `None` for path-style and non-S3 hosts.
    pub bucket: Option<String>,
}
234
/// Hostname suffix of LocalStack-style endpoints, including the leading dot.
const LOCALSTACK_SUFFIX: &str = ".localhost.localstack.cloud";
/// Hostname suffix of real AWS endpoints, including the leading dot.
const AWS_SUFFIX: &str = ".amazonaws.com";
237
238pub fn parse_routing_host(host: &str) -> Option<RoutingHost> {
242 let hostname = host.split(':').next()?;
243 if hostname.is_empty() {
244 return None;
245 }
246 let hostname = hostname.to_ascii_lowercase();
247 if let Some(prefix) = hostname.strip_suffix(LOCALSTACK_SUFFIX) {
248 return parse_localstack_prefix(prefix);
249 }
250 if hostname == "amazonaws.com" {
251 return None;
252 }
253 if let Some(prefix) = hostname.strip_suffix(AWS_SUFFIX) {
254 return parse_aws_prefix(prefix);
255 }
256 None
257}
258
259pub fn parse_routing_host_from_headers(headers: &HeaderMap) -> Option<RoutingHost> {
261 let host = headers.get("host")?.to_str().ok()?;
262 parse_routing_host(host)
263}
264
265fn parse_localstack_prefix(prefix: &str) -> Option<RoutingHost> {
266 if prefix.is_empty() {
267 return None;
268 }
269 let labels: Vec<&str> = prefix.split('.').collect();
270 if labels.iter().any(|l| l.is_empty()) {
271 return None;
272 }
273 match labels.len() {
274 2 => Some(RoutingHost {
275 service: labels[0].to_string(),
276 region: labels[1].to_string(),
277 bucket: None,
278 }),
279 n if n >= 3 && labels[n - 2] == "s3" => {
280 let bucket = labels[..n - 2].join(".");
281 Some(RoutingHost {
282 service: "s3".to_string(),
283 region: labels[n - 1].to_string(),
284 bucket: Some(bucket),
285 })
286 }
287 n if n >= 3 && labels[n - 2] == "s3-accesspoint" => {
288 let bucket = labels[..n - 2].join(".");
289 Some(RoutingHost {
290 service: "s3".to_string(),
291 region: labels[n - 1].to_string(),
292 bucket: Some(bucket),
293 })
294 }
295 n if n >= 3 && labels[n - 2] == "s3-control" => Some(RoutingHost {
296 service: "s3".to_string(),
297 region: labels[n - 1].to_string(),
298 bucket: None,
299 }),
300 _ => None,
301 }
302}
303
304fn parse_aws_prefix(prefix: &str) -> Option<RoutingHost> {
316 if prefix.is_empty() {
317 return None;
318 }
319 let labels: Vec<&str> = prefix.split('.').collect();
320 if labels.iter().any(|l| l.is_empty()) {
321 return None;
322 }
323 let last = *labels.last()?;
324
325 if let Some(region) = last.strip_prefix("s3-") {
328 if !region.is_empty() {
329 let bucket = if labels.len() >= 2 {
330 Some(labels[..labels.len() - 1].join("."))
331 } else {
332 None
333 };
334 return Some(RoutingHost {
335 service: "s3".to_string(),
336 region: region.to_string(),
337 bucket,
338 });
339 }
340 }
341
342 if last == "s3" {
346 if labels.len() == 1 {
347 return Some(RoutingHost {
348 service: "s3".to_string(),
349 region: "us-east-1".to_string(),
350 bucket: None,
351 });
352 }
353 return Some(RoutingHost {
354 service: "s3".to_string(),
355 region: "us-east-1".to_string(),
356 bucket: Some(labels[..labels.len() - 1].join(".")),
357 });
358 }
359
360 if last == "s3-accesspoint" {
363 if labels.len() == 2 {
364 return Some(RoutingHost {
365 service: "s3".to_string(),
366 region: labels[0].to_string(),
367 bucket: None,
368 });
369 }
370 if labels.len() >= 3 {
374 let bucket = labels[..labels.len() - 2].join(".");
375 return Some(RoutingHost {
376 service: "s3".to_string(),
377 region: labels[labels.len() - 1].to_string(),
378 bucket: Some(bucket),
379 });
380 }
381 }
382
383 if labels.len() >= 2 && labels[labels.len() - 2] == "s3-control" {
386 return Some(RoutingHost {
387 service: "s3".to_string(),
388 region: last.to_string(),
389 bucket: None,
390 });
391 }
392
393 match labels.len() {
394 2 => Some(RoutingHost {
397 service: labels[0].to_string(),
398 region: labels[1].to_string(),
399 bucket: None,
400 }),
401 n if n >= 3 && labels[n - 2] == "s3" => {
403 let bucket = labels[..n - 2].join(".");
404 Some(RoutingHost {
405 service: "s3".to_string(),
406 region: labels[n - 1].to_string(),
407 bucket: Some(bucket),
408 })
409 }
410 _ => None,
411 }
412}
413
414fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
417 let (prefix, action) = target.rsplit_once('.')?;
418
419 let service = match prefix {
420 "AWSEvents" => "events",
421 "AmazonSSM" => "ssm",
422 "AmazonSQS" => "sqs",
423 "AmazonSNS" => "sns",
424 "DynamoDB_20120810" => "dynamodb",
425 "DynamoDBStreams_20120810" => "dynamodbstreams",
426 "Logs_20140328" => "logs",
427 s if s.starts_with("secretsmanager") => "secretsmanager",
428 s if s.starts_with("TrentService") => "kms",
429 s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
430 s if s.starts_with("AWSCognitoIdentityService") => "cognito-identity",
431 s if s.starts_with("Kinesis_20131202") => "kinesis",
432 s if s.starts_with("AmazonEC2ContainerRegistry_V") => "ecr",
433 s if s.starts_with("AmazonEC2ContainerServiceV") => "ecs",
434 s if s.starts_with("AWSStepFunctions") => "states",
435 s if s.starts_with("AWSOrganizationsV") => "organizations",
436 "CertificateManager" => "acm",
437 "AnyScaleFrontendService" => "application-autoscaling",
438 "AWSWAF_20190729" => "wafv2",
441 "AmazonAthena" => "athena",
442 s if s.starts_with("Firehose_") => "firehose",
443 "AWSGlue" => "glue",
444 "CloudApiService" => "cloudcontrolapi",
445 "ResourceGroupsTaggingAPI_20170126" => "tagging",
446 "AmazonMemoryDB" => "memorydb",
447 s if s.starts_with("GraniteServiceVersion") => "monitoring",
453 _ => return None,
454 };
455
456 Some(DetectedRequest {
457 service: service.to_string(),
458 action: action.to_string(),
459 protocol: AwsProtocol::Json,
460 })
461}
462
463fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
465 if REST_XML_SERVICES.contains(&service) {
466 Some(AwsProtocol::Rest)
467 } else if REST_JSON_SERVICES.contains(&service) {
468 Some(AwsProtocol::RestJson)
469 } else {
470 None
471 }
472}
473
/// Guesses the owning service of a Query-protocol `Action` name.
///
/// Used when a request carries no SigV4 credential scope to name the service.
/// Only a fixed set of STS, IAM, SES and SNS actions are recognised; any
/// other action yields `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    const SES_ACTIONS: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];
    const SNS_ACTIONS: &[&str] = &["ConfirmSubscription", "Unsubscribe"];

    // Checked in this order; the first table containing the action wins.
    const OWNERS: &[(&[&str], &str)] = &[
        (STS_ACTIONS, "sts"),
        (IAM_ACTIONS, "iam"),
        (SES_ACTIONS, "ses"),
        (SNS_ACTIONS, "sns"),
    ];

    OWNERS
        .iter()
        .find(|(actions, _)| actions.contains(&action))
        .map(|&(_, service)| service.to_string())
}
545
546fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
548 let auth = headers.get("authorization")?.to_str().ok()?;
549 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
550 Some(normalize_service_name(&info.service).to_string())
551}
552
/// Collapses signing-name aliases onto the service name used for routing.
///
/// `bedrock-runtime` maps to `bedrock` and `apigatewayv2` to `apigateway`;
/// every other name (including the empty string) is returned unchanged.
fn normalize_service_name(service: &str) -> &str {
    const ALIASES: &[(&str, &str)] = &[
        ("bedrock-runtime", "bedrock"),
        ("apigatewayv2", "apigateway"),
    ];
    ALIASES
        .iter()
        .find(|(alias, _)| *alias == service)
        .map_or(service, |&(_, canonical)| canonical)
}
578
579pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
581 decode_form_urlencoded(body)
582}
583
584pub fn flatten_json_to_query(body: &Bytes) -> HashMap<String, String> {
601 let mut out = HashMap::new();
602 let Ok(value) = serde_json::from_slice::<serde_json::Value>(body) else {
603 return out;
604 };
605 if value.is_object() {
606 flatten_json_value("", &value, &mut out);
607 }
608 out
609}
610
611fn flatten_json_value(prefix: &str, value: &serde_json::Value, out: &mut HashMap<String, String>) {
612 match value {
613 serde_json::Value::Object(map) => {
614 for (k, v) in map {
615 let child = if prefix.is_empty() {
616 k.clone()
617 } else {
618 format!("{prefix}.{k}")
619 };
620 flatten_json_value(&child, v, out);
621 }
622 }
623 serde_json::Value::Array(items) => {
624 for (i, v) in items.iter().enumerate() {
625 let child = format!("{prefix}.member.{}", i + 1);
626 flatten_json_value(&child, v, out);
627 }
628 }
629 serde_json::Value::Null => {}
630 serde_json::Value::String(s) => {
631 out.insert(prefix.to_string(), s.clone());
632 }
633 serde_json::Value::Bool(b) => {
634 out.insert(prefix.to_string(), b.to_string());
635 }
636 serde_json::Value::Number(n) => {
637 out.insert(prefix.to_string(), n.to_string());
638 }
639 }
640}
641
642fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
643 let s = std::str::from_utf8(input).unwrap_or("");
644 let mut result = HashMap::new();
645 for pair in s.split('&') {
646 if pair.is_empty() {
647 continue;
648 }
649 let (key, value) = match pair.find('=') {
650 Some(pos) => (&pair[..pos], &pair[pos + 1..]),
651 None => (pair, ""),
652 };
653 result.insert(url_decode(key), url_decode(value));
654 }
655 result
656}
657
/// Decodes one `application/x-www-form-urlencoded` component.
///
/// `+` becomes a space and `%XY` becomes the byte `0xXY`. The resulting bytes
/// are interpreted as UTF-8, so multi-byte characters decode correctly
/// (`%C3%A9` -> `é`, raw `é` stays `é`). Invalid UTF-8 sequences become
/// U+FFFD. A `%` not followed by two hex digits is dropped together with the
/// (up to two) bytes that follow it.
fn url_decode(input: &str) -> String {
    // Collect raw bytes first: pushing each byte as a `char` would treat UTF-8
    // as Latin-1 and corrupt every non-ASCII character.
    let mut decoded: Vec<u8> = Vec::with_capacity(input.len());
    let mut bytes = input.bytes();
    while let Some(b) = bytes.next() {
        match b {
            b'+' => decoded.push(b' '),
            b'%' => {
                // Always consume two bytes so a malformed escape is skipped as a unit.
                let high = bytes.next().and_then(from_hex);
                let low = bytes.next().and_then(from_hex);
                if let (Some(h), Some(l)) = (high, low) {
                    decoded.push(h << 4 | l);
                }
            }
            other => decoded.push(other),
        }
    }
    // Valid UTF-8 (the common case) is reused without copying.
    String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}

/// Value of a single ASCII hex digit (either case), or `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
685
// Unit tests for X-Amz-Target parsing, action-based service inference,
// form/JSON-to-query decoding, Host-header routing, and the end-to-end
// `detect_service` priority order.
#[cfg(test)]
mod tests {
    use super::*;

    // --- X-Amz-Target parsing --------------------------------------------

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_kinesis() {
        let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
        assert_eq!(result.service, "kinesis");
        assert_eq!(result.action, "ListStreams");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    // --- Form-encoded body decoding --------------------------------------

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }

    #[test]
    fn parse_query_body_empty_returns_empty_map() {
        let body = Bytes::from("");
        let params = parse_query_body(&body);
        assert!(params.is_empty());
    }

    #[test]
    fn parse_query_body_duplicate_keys_last_wins() {
        let body = Bytes::from("key=a&key=b");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "b");
    }

    #[test]
    fn parse_query_body_single_key() {
        let body = Bytes::from("key=value");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "value");
    }

    #[test]
    fn parse_amz_target_ecs() {
        let result = parse_amz_target("AmazonEC2ContainerServiceV20141113.ListClusters").unwrap();
        assert_eq!(result.service, "ecs");
        assert_eq!(result.action, "ListClusters");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_invalid_returns_none() {
        assert!(parse_amz_target("NoDotHere").is_none());
        assert!(parse_amz_target("").is_none());
    }

    #[test]
    fn parse_amz_target_cloudwatch_json() {
        let result = parse_amz_target("GraniteServiceVersion20100801.PutMetricData").unwrap();
        // The GraniteServiceVersion* prefix routes to the "monitoring" service.
        assert_eq!(result.service, "monitoring");
        assert_eq!(result.action, "PutMetricData");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    // --- JSON to Query flattening ----------------------------------------

    #[test]
    fn flatten_json_to_query_nested() {
        let body = Bytes::from(
            serde_json::json!({
                "Namespace": "MyApp",
                "MetricData": [{
                    "MetricName": "Latency",
                    "Value": 12.5,
                    "StatisticValues": {"SampleCount": 3, "Sum": 10},
                    "Dimensions": [{"Name": "Endpoint", "Value": "/api"}]
                }]
            })
            .to_string(),
        );
        let flat = flatten_json_to_query(&body);
        assert_eq!(flat.get("Namespace").unwrap(), "MyApp");
        assert_eq!(
            flat.get("MetricData.member.1.MetricName").unwrap(),
            "Latency"
        );
        assert_eq!(flat.get("MetricData.member.1.Value").unwrap(), "12.5");
        assert_eq!(
            flat.get("MetricData.member.1.StatisticValues.SampleCount")
                .unwrap(),
            "3"
        );
        assert_eq!(
            flat.get("MetricData.member.1.Dimensions.member.1.Name")
                .unwrap(),
            "Endpoint"
        );
        assert_eq!(
            flat.get("MetricData.member.1.Dimensions.member.1.Value")
                .unwrap(),
            "/api"
        );
    }

    #[test]
    fn flatten_json_to_query_non_object_is_empty() {
        assert!(flatten_json_to_query(&Bytes::from_static(b"[]")).is_empty());
        assert!(flatten_json_to_query(&Bytes::from_static(b"not json")).is_empty());
    }

    #[test]
    fn parse_amz_target_various_prefixes() {
        assert_eq!(
            parse_amz_target("AmazonSQS.SendMessage").unwrap().service,
            "sqs"
        );
        assert_eq!(
            parse_amz_target("AmazonSNS.Publish").unwrap().service,
            "sns"
        );
        assert_eq!(
            parse_amz_target("DynamoDB_20120810.GetItem")
                .unwrap()
                .service,
            "dynamodb"
        );
        assert_eq!(
            parse_amz_target("Logs_20140328.PutLogEvents")
                .unwrap()
                .service,
            "logs"
        );
        assert_eq!(
            parse_amz_target("secretsmanager.GetSecretValue")
                .unwrap()
                .service,
            "secretsmanager"
        );
        assert_eq!(
            parse_amz_target("TrentService.Encrypt").unwrap().service,
            "kms"
        );
        assert_eq!(
            parse_amz_target("AWSCognitoIdentityProviderService.InitiateAuth")
                .unwrap()
                .service,
            "cognito-idp"
        );
        assert_eq!(
            parse_amz_target("AWSStepFunctions.StartExecution")
                .unwrap()
                .service,
            "states"
        );
        assert_eq!(
            parse_amz_target("AWSOrganizationsV20161128.CreateOrganization")
                .unwrap()
                .service,
            "organizations"
        );
        assert!(parse_amz_target("UnknownServicePrefix.Action").is_none());
    }

    // --- Service inference from Query `Action` names ---------------------

    #[test]
    fn infer_service_from_action_maps_sts() {
        assert_eq!(
            infer_service_from_action("AssumeRole").as_deref(),
            Some("sts")
        );
        assert_eq!(
            infer_service_from_action("GetCallerIdentity").as_deref(),
            Some("sts")
        );
    }

    #[test]
    fn infer_service_from_action_maps_iam() {
        assert_eq!(
            infer_service_from_action("CreateUser").as_deref(),
            Some("iam")
        );
        assert_eq!(
            infer_service_from_action("ListRoles").as_deref(),
            Some("iam")
        );
    }

    #[test]
    fn infer_service_from_action_maps_ses() {
        assert_eq!(
            infer_service_from_action("SendEmail").as_deref(),
            Some("ses")
        );
        assert_eq!(
            infer_service_from_action("ListIdentities").as_deref(),
            Some("ses")
        );
    }

    #[test]
    fn infer_service_from_action_maps_sns_confirmation_flow() {
        assert_eq!(
            // Needed for requests without a SigV4 credential scope, where the
            // action name is the only service signal.
            infer_service_from_action("ConfirmSubscription").as_deref(),
            Some("sns")
        );
        assert_eq!(
            infer_service_from_action("Unsubscribe").as_deref(),
            Some("sns")
        );
    }

    #[test]
    fn detect_service_routes_unsigned_confirm_subscription_to_sns() {
        let mut headers = HeaderMap::new();
        // Only a plain Host header: no Authorization and no routing hostname,
        // so the service has to come from the action name.
        headers.insert("host", "localhost:4566".parse().unwrap());
        let mut query_params = HashMap::new();
        query_params.insert("Action".to_string(), "ConfirmSubscription".to_string());
        query_params.insert(
            "TopicArn".to_string(),
            "arn:aws:sns:us-east-1:000000000000:t".to_string(),
        );
        query_params.insert("Token".to_string(), "abc123".to_string());

        let detected = detect_service(&headers, &query_params, &Bytes::new())
            .expect("ConfirmSubscription must route to a service");
        assert_eq!(detected.service, "sns");
        assert_eq!(detected.action, "ConfirmSubscription");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn infer_service_from_action_unknown_returns_none() {
        assert!(infer_service_from_action("NotARealAction").is_none());
    }

    #[test]
    fn rest_protocol_for_returns_none_for_non_rest_service() {
        assert!(rest_protocol_for("sqs").is_none());
    }

    // --- Percent-decoding helpers ----------------------------------------

    #[test]
    fn url_decode_handles_percent_and_plus() {
        assert_eq!(url_decode("hello+world"), "hello world");
        assert_eq!(url_decode("hello%20world"), "hello world");
        assert_eq!(url_decode("100%25"), "100%");
    }

    #[test]
    fn url_decode_ignores_malformed_percent() {
        // The `%` and both following bytes are consumed and dropped.
        assert_eq!(url_decode("%ZZ"), "");
    }

    #[test]
    fn from_hex_valid_digits() {
        assert_eq!(from_hex(b'0'), Some(0));
        assert_eq!(from_hex(b'9'), Some(9));
        assert_eq!(from_hex(b'a'), Some(10));
        assert_eq!(from_hex(b'F'), Some(15));
    }

    #[test]
    fn from_hex_invalid_returns_none() {
        assert!(from_hex(b'g').is_none());
        assert!(from_hex(b' ').is_none());
    }

    // --- End-to-end detection --------------------------------------------

    #[test]
    fn detect_service_via_amz_target() {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", "AmazonSSM.GetParameter".parse().unwrap());
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ssm");
        assert_eq!(detected.action, "GetParameter");
    }

    #[test]
    fn detect_service_via_query_action_with_inferred_service() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "AssumeRole".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "AssumeRole");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_via_form_body() {
        // No headers at all: the Action comes from the form-encoded body.
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::from("Action=SendEmail&Source=x%40y.com");
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ses");
        assert_eq!(detected.action, "SendEmail");
    }

    #[test]
    fn detect_service_via_sigv2_presigned() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("AWSAccessKeyId".to_string(), "AKID".to_string());
        query.insert("Signature".to_string(), "sig".to_string());
        query.insert("Expires".to_string(), "1234567890".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_via_sigv4_presigned_credential() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/s3/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_unknown_returns_none() {
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::new();
        assert!(detect_service(&headers, &query, &body).is_none());
    }

    // --- Signing-name normalisation --------------------------------------

    #[test]
    fn normalize_service_name_aliases_apigatewayv2_to_apigateway() {
        assert_eq!(normalize_service_name("apigatewayv2"), "apigateway");
    }

    #[test]
    fn normalize_service_name_aliases_bedrock_runtime_to_bedrock() {
        assert_eq!(normalize_service_name("bedrock-runtime"), "bedrock");
    }

    #[test]
    fn normalize_service_name_passes_through_unaliased_services() {
        // Already-canonical and unknown names must come back unchanged.
        assert_eq!(normalize_service_name("bedrock"), "bedrock");
        assert_eq!(normalize_service_name("s3"), "s3");
        assert_eq!(normalize_service_name("lambda"), "lambda");
        assert_eq!(normalize_service_name(""), "");
        assert_eq!(
            normalize_service_name("unknown-future-service"),
            "unknown-future-service"
        );
    }

    #[test]
    fn detect_service_via_authorization_header_normalizes_bedrock_runtime() {
        let mut headers = HeaderMap::new();
        // Signing name "bedrock-runtime" must resolve to the "bedrock" REST-JSON service.
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 \
             Credential=AKID/20240101/us-east-1/bedrock-runtime/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn detect_service_via_sigv4_presigned_credential_normalizes_bedrock_runtime() {
        let headers = HeaderMap::new();
        // Same normalisation, but via a presigned X-Amz-Credential query parameter.
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/bedrock-runtime/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    // --- Host-header routing ---------------------------------------------

    #[test]
    fn parse_routing_host_localstack_basic() {
        let h = parse_routing_host("sqs.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_localstack_with_port() {
        let h = parse_routing_host("lambda.eu-west-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "eu-west-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_case_insensitive() {
        let h = parse_routing_host("SQS.US-EAST-1.LOCALHOST.LOCALSTACK.CLOUD:4566").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");

        let h = parse_routing_host("LAMBDA.US-EAST-1.AMAZONAWS.COM").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "us-east-1");
    }

    #[test]
    fn parse_routing_host_localstack_s3_virtual_hosted() {
        let h =
            parse_routing_host("my-bucket.s3.us-east-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_localstack_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_service_region() {
        let h = parse_routing_host("sqs.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("dynamodb.eu-west-2.amazonaws.com:443").unwrap();
        assert_eq!(h.service, "dynamodb");
        assert_eq!(h.region, "eu-west-2");
    }

    #[test]
    fn parse_routing_host_aws_s3_path_style_modern() {
        let h = parse_routing_host("s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_aws_s3_virtual_hosted_modern() {
        let h = parse_routing_host("my-bucket.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_s3_legacy_global() {
        let h = parse_routing_host("s3.amazonaws.com").unwrap();
        // The legacy global endpoint carries no region; the parser reports us-east-1.
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_legacy_global_dotted_bucket() {
        let h = parse_routing_host("a.b.c.s3.amazonaws.com").unwrap();
        // Every label before "s3" belongs to the bucket name.
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_s3_dash_separated() {
        let h = parse_routing_host("s3-us-west-2.amazonaws.com").unwrap();
        // Legacy `s3-<region>` form: the region is embedded in the S3 label.
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3-us-west-2.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_rejects_plain_localhost() {
        assert!(parse_routing_host("localhost:4566").is_none());
        assert!(parse_routing_host("127.0.0.1:4566").is_none());
    }

    #[test]
    fn parse_routing_host_rejects_unknown_suffix() {
        assert!(parse_routing_host("sqs.us-east-1.example.com").is_none());
        assert!(parse_routing_host("s3.us-east-1.aws").is_none());
    }

    #[test]
    fn parse_routing_host_empty_and_malformed_rejected() {
        assert!(parse_routing_host("").is_none());
        assert!(parse_routing_host(".localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("..localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("sqs.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("foo.bar.baz.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host(".amazonaws.com").is_none());
        assert!(parse_routing_host("amazonaws.com").is_none());
    }

    #[test]
    fn parse_routing_host_bare_s3_accesspoint_does_not_panic() {
        // Regression guard named for a panic. This input has no recognised suffix,
        // so parsing returns None before any label indexing happens.
        assert!(parse_routing_host("s3-accesspoint").is_none());
    }

    #[test]
    fn detect_service_via_host_for_rest_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_via_host_for_rest_json_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "lambda");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn detect_service_via_host_plus_query_action() {
        // "ListQueues" is not in the action table, so the service comes from the Host.
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "sqs.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "ListQueues".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sqs");
        assert_eq!(detected.action, "ListQueues");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_sigv4_wins_over_host() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 Credential=AKID/20240101/us-east-1/s3/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        // The signed service decides the protocol, not the conflicting Host header.
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_host_for_virtual_hosted_s3() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "my-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }
}
1350}