1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Protocol family of a detected AWS request; tells the dispatcher how to read
/// the request's parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AwsProtocol {
    /// `Action`-based requests whose parameters arrive in the query string or in a
    /// form-encoded body.
    Query,
    /// Requests routed by the `X-Amz-Target` header (`<Prefix>.<Operation>`).
    Json,
    /// REST services listed in `REST_XML_SERVICES`, plus SigV2-presigned S3 URLs.
    Rest,
    /// REST services listed in `REST_JSON_SERVICES`.
    RestJson,
}
21
/// Service names that `rest_protocol_for` maps to [`AwsProtocol::Rest`] (the REST-XML family).
const REST_XML_SERVICES: &[&str] = &["s3", "cloudfront", "route53"];
24
/// Service names that `rest_protocol_for` maps to [`AwsProtocol::RestJson`].
///
/// Names are compared after `normalize_service_name`, so aliases such as
/// `bedrock-runtime` and `apigatewayv2` resolve through `bedrock` / `apigateway`.
const REST_JSON_SERVICES: &[&str] = &[
    "lambda",
    "ses",
    "apigateway",
    "bedrock",
    "bedrock-agent",
    "bedrock-agent-runtime",
    "scheduler",
    "batch",
    "pipes",
    "rds-data",
    "dsql",
    "resource-groups",
    "eks",
    "account",
];
42
/// Result of request detection: which service should handle the request, and how.
#[derive(Debug, Clone)]
pub struct DetectedRequest {
    /// Canonical service name (e.g. `"sqs"`, `"s3"`, `"bedrock"`).
    pub service: String,
    /// Operation name for Query/JSON requests; empty for REST requests, whose
    /// operation is determined later from method and path.
    pub action: String,
    /// Protocol family used to decode the request.
    pub protocol: AwsProtocol,
}
50
51pub fn detect_service_headers_only(
58 headers: &HeaderMap,
59 query_params: &HashMap<String, String>,
60) -> Option<DetectedRequest> {
61 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
63 return parse_amz_target(target);
64 }
65 if let Some(action) = query_params.get("Action") {
66 let service = extract_service_from_auth(headers)
67 .or_else(|| infer_service_from_action(action))
68 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
69 if let Some(service) = service {
70 return Some(DetectedRequest {
71 service,
72 action: action.clone(),
73 protocol: AwsProtocol::Query,
74 });
75 }
76 }
77 if let Some(service) = extract_service_from_auth(headers) {
78 if let Some(protocol) = rest_protocol_for(&service) {
79 return Some(DetectedRequest {
80 service,
81 action: String::new(),
82 protocol,
83 });
84 }
85 }
86 if let Some(credential) = query_params.get("X-Amz-Credential") {
87 let parts: Vec<&str> = credential.split('/').collect();
88 if parts.len() >= 4 {
89 let service = normalize_service_name(parts[3]).to_string();
90 if let Some(protocol) = rest_protocol_for(&service) {
91 return Some(DetectedRequest {
92 service,
93 action: String::new(),
94 protocol,
95 });
96 }
97 }
98 }
99 if query_params.contains_key("AWSAccessKeyId")
100 && query_params.contains_key("Signature")
101 && query_params.contains_key("Expires")
102 {
103 return Some(DetectedRequest {
104 service: "s3".to_string(),
105 action: String::new(),
106 protocol: AwsProtocol::Rest,
107 });
108 }
109 if let Some(host_info) = parse_routing_host_from_headers(headers) {
110 if let Some(protocol) = rest_protocol_for(&host_info.service) {
111 return Some(DetectedRequest {
112 service: host_info.service,
113 action: String::new(),
114 protocol,
115 });
116 }
117 }
118 None
119}
120
121pub fn detect_service(
123 headers: &HeaderMap,
124 query_params: &HashMap<String, String>,
125 body: &Bytes,
126) -> Option<DetectedRequest> {
127 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
129 return parse_amz_target(target);
130 }
131
132 if let Some(action) = query_params.get("Action") {
134 let service = extract_service_from_auth(headers)
135 .or_else(|| infer_service_from_action(action))
136 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
137 if let Some(service) = service {
138 return Some(DetectedRequest {
139 service,
140 action: action.clone(),
141 protocol: AwsProtocol::Query,
142 });
143 }
144 }
145
146 {
148 let form_params = decode_form_urlencoded(body);
149
150 if let Some(action) = form_params.get("Action") {
151 let service = extract_service_from_auth(headers)
152 .or_else(|| infer_service_from_action(action))
153 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
154 if let Some(service) = service {
155 return Some(DetectedRequest {
156 service,
157 action: action.clone(),
158 protocol: AwsProtocol::Query,
159 });
160 }
161 }
162 }
163
164 if let Some(service) = extract_service_from_auth(headers) {
166 if let Some(protocol) = rest_protocol_for(&service) {
167 return Some(DetectedRequest {
168 service,
169 action: String::new(), protocol,
171 });
172 }
173 }
174
175 if let Some(credential) = query_params.get("X-Amz-Credential") {
177 let parts: Vec<&str> = credential.split('/').collect();
179 if parts.len() >= 4 {
180 let service = normalize_service_name(parts[3]).to_string();
181 if let Some(protocol) = rest_protocol_for(&service) {
182 return Some(DetectedRequest {
183 service,
184 action: String::new(),
185 protocol,
186 });
187 }
188 }
189 }
190
191 if query_params.contains_key("AWSAccessKeyId")
195 && query_params.contains_key("Signature")
196 && query_params.contains_key("Expires")
197 {
198 return Some(DetectedRequest {
199 service: "s3".to_string(),
200 action: String::new(),
201 protocol: AwsProtocol::Rest,
202 });
203 }
204
205 if let Some(host_info) = parse_routing_host_from_headers(headers) {
209 if let Some(protocol) = rest_protocol_for(&host_info.service) {
210 return Some(DetectedRequest {
211 service: host_info.service,
212 action: String::new(),
213 protocol,
214 });
215 }
216 }
217
218 None
219}
220
/// Routing information extracted from a `Host` header by `parse_routing_host`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoutingHost {
    /// Service label from the hostname (always `"s3"` for S3-family endpoints).
    pub service: String,
    /// Region label from the hostname (`"us-east-1"` for legacy global S3).
    pub region: String,
    /// For S3 virtual-hosted requests, the (possibly dotted) leading labels naming
    /// the bucket; `None` for path-style and non-S3 hosts.
    pub bucket: Option<String>,
}
236
/// LocalStack-compatible routing domain; the labels before it encode service/region.
const LOCALSTACK_SUFFIX: &str = ".localhost.localstack.cloud";
/// Real AWS domain; the labels before it encode service/region (and S3 variants).
const AWS_SUFFIX: &str = ".amazonaws.com";
239
240pub fn parse_routing_host(host: &str) -> Option<RoutingHost> {
244 let hostname = host.split(':').next()?;
245 if hostname.is_empty() {
246 return None;
247 }
248 let hostname = hostname.to_ascii_lowercase();
249 if let Some(prefix) = hostname.strip_suffix(LOCALSTACK_SUFFIX) {
250 return parse_localstack_prefix(prefix);
251 }
252 if hostname == "amazonaws.com" {
253 return None;
254 }
255 if let Some(prefix) = hostname.strip_suffix(AWS_SUFFIX) {
256 return parse_aws_prefix(prefix);
257 }
258 None
259}
260
261pub fn parse_routing_host_from_headers(headers: &HeaderMap) -> Option<RoutingHost> {
263 let host = headers.get("host")?.to_str().ok()?;
264 parse_routing_host(host)
265}
266
267fn parse_localstack_prefix(prefix: &str) -> Option<RoutingHost> {
268 if prefix.is_empty() {
269 return None;
270 }
271 let labels: Vec<&str> = prefix.split('.').collect();
272 if labels.iter().any(|l| l.is_empty()) {
273 return None;
274 }
275 match labels.len() {
276 2 => Some(RoutingHost {
277 service: labels[0].to_string(),
278 region: labels[1].to_string(),
279 bucket: None,
280 }),
281 n if n >= 3 && labels[n - 2] == "s3" => {
282 let bucket = labels[..n - 2].join(".");
283 Some(RoutingHost {
284 service: "s3".to_string(),
285 region: labels[n - 1].to_string(),
286 bucket: Some(bucket),
287 })
288 }
289 n if n >= 3 && labels[n - 2] == "s3-accesspoint" => {
290 let bucket = labels[..n - 2].join(".");
291 Some(RoutingHost {
292 service: "s3".to_string(),
293 region: labels[n - 1].to_string(),
294 bucket: Some(bucket),
295 })
296 }
297 n if n >= 3 && labels[n - 2] == "s3-control" => Some(RoutingHost {
298 service: "s3".to_string(),
299 region: labels[n - 1].to_string(),
300 bucket: None,
301 }),
302 _ => None,
303 }
304}
305
306fn parse_aws_prefix(prefix: &str) -> Option<RoutingHost> {
318 if prefix.is_empty() {
319 return None;
320 }
321 let labels: Vec<&str> = prefix.split('.').collect();
322 if labels.iter().any(|l| l.is_empty()) {
323 return None;
324 }
325 let last = *labels.last()?;
326
327 if let Some(region) = last.strip_prefix("s3-") {
330 if !region.is_empty() {
331 let bucket = if labels.len() >= 2 {
332 Some(labels[..labels.len() - 1].join("."))
333 } else {
334 None
335 };
336 return Some(RoutingHost {
337 service: "s3".to_string(),
338 region: region.to_string(),
339 bucket,
340 });
341 }
342 }
343
344 if last == "s3" {
348 if labels.len() == 1 {
349 return Some(RoutingHost {
350 service: "s3".to_string(),
351 region: "us-east-1".to_string(),
352 bucket: None,
353 });
354 }
355 return Some(RoutingHost {
356 service: "s3".to_string(),
357 region: "us-east-1".to_string(),
358 bucket: Some(labels[..labels.len() - 1].join(".")),
359 });
360 }
361
362 if last == "s3-accesspoint" {
365 if labels.len() == 2 {
366 return Some(RoutingHost {
367 service: "s3".to_string(),
368 region: labels[0].to_string(),
369 bucket: None,
370 });
371 }
372 if labels.len() >= 3 {
376 let bucket = labels[..labels.len() - 2].join(".");
377 return Some(RoutingHost {
378 service: "s3".to_string(),
379 region: labels[labels.len() - 1].to_string(),
380 bucket: Some(bucket),
381 });
382 }
383 }
384
385 if labels.len() >= 2 && labels[labels.len() - 2] == "s3-control" {
388 return Some(RoutingHost {
389 service: "s3".to_string(),
390 region: last.to_string(),
391 bucket: None,
392 });
393 }
394
395 match labels.len() {
396 2 => Some(RoutingHost {
399 service: labels[0].to_string(),
400 region: labels[1].to_string(),
401 bucket: None,
402 }),
403 n if n >= 3 && labels[n - 2] == "s3" => {
405 let bucket = labels[..n - 2].join(".");
406 Some(RoutingHost {
407 service: "s3".to_string(),
408 region: labels[n - 1].to_string(),
409 bucket: Some(bucket),
410 })
411 }
412 _ => None,
413 }
414}
415
416fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
419 let (prefix, action) = target.rsplit_once('.')?;
420
421 let service = match prefix {
422 "AWSEvents" => "events",
423 "AmazonSSM" => "ssm",
424 "AmazonSQS" => "sqs",
425 "AmazonSNS" => "sns",
426 "DynamoDB_20120810" => "dynamodb",
427 "DynamoDBStreams_20120810" => "dynamodbstreams",
428 "Logs_20140328" => "logs",
429 s if s.starts_with("secretsmanager") => "secretsmanager",
430 s if s.starts_with("TrentService") => "kms",
431 s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
432 s if s.starts_with("AWSCognitoIdentityService") => "cognito-identity",
433 s if s.starts_with("Kinesis_20131202") => "kinesis",
434 s if s.starts_with("AmazonEC2ContainerRegistry_V") => "ecr",
435 s if s.starts_with("AmazonEC2ContainerServiceV") => "ecs",
436 s if s.starts_with("AWSStepFunctions") => "states",
437 s if s.starts_with("AWSOrganizationsV") => "organizations",
438 "CertificateManager" => "acm",
439 "AnyScaleFrontendService" => "application-autoscaling",
440 "AWSWAF_20190729" => "wafv2",
443 "AmazonAthena" => "athena",
444 s if s.starts_with("Firehose_") => "firehose",
445 "AWSGlue" => "glue",
446 "CloudApiService" => "cloudcontrolapi",
447 "ResourceGroupsTaggingAPI_20170126" => "tagging",
448 "AmazonMemoryDB" => "memorydb",
449 "Route53AutoNaming_v20170314" => "servicediscovery",
452 "AWSIdentityStore" => "identitystore",
454 "SWBExternalService" => "sso",
456 s if s.starts_with("GraniteServiceVersion") => "monitoring",
462 _ => return None,
463 };
464
465 Some(DetectedRequest {
466 service: service.to_string(),
467 action: action.to_string(),
468 protocol: AwsProtocol::Json,
469 })
470}
471
472fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
474 if REST_XML_SERVICES.contains(&service) {
475 Some(AwsProtocol::Rest)
476 } else if REST_JSON_SERVICES.contains(&service) {
477 Some(AwsProtocol::RestJson)
478 } else {
479 None
480 }
481}
482
/// Guesses the service of a Query-protocol request from its `Action` name alone.
///
/// Used when the request carries no SigV4 credential scope. Only the STS, IAM,
/// SES (v1) and SNS confirmation actions listed here are recognised; anything
/// else yields `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    let service = match action {
        "AssumeRole"
        | "AssumeRoleWithSAML"
        | "AssumeRoleWithWebIdentity"
        | "GetCallerIdentity"
        | "GetSessionToken"
        | "GetFederationToken"
        | "GetAccessKeyInfo"
        | "DecodeAuthorizationMessage" => "sts",
        "CreateUser" | "DeleteUser" | "GetUser" | "ListUsers" | "CreateRole" | "DeleteRole"
        | "GetRole" | "ListRoles" | "CreatePolicy" | "DeletePolicy" | "GetPolicy"
        | "ListPolicies" | "AttachRolePolicy" | "DetachRolePolicy" | "CreateAccessKey"
        | "DeleteAccessKey" | "ListAccessKeys" | "ListRolePolicies" => "iam",
        "VerifyEmailIdentity"
        | "VerifyDomainIdentity"
        | "VerifyDomainDkim"
        | "ListIdentities"
        | "GetIdentityVerificationAttributes"
        | "GetIdentityDkimAttributes"
        | "DeleteIdentity"
        | "SetIdentityDkimEnabled"
        | "SetIdentityNotificationTopic"
        | "SetIdentityFeedbackForwardingEnabled"
        | "GetIdentityNotificationAttributes"
        | "GetIdentityMailFromDomainAttributes"
        | "SetIdentityMailFromDomain"
        | "SendEmail"
        | "SendRawEmail"
        | "SendTemplatedEmail"
        | "SendBulkTemplatedEmail"
        | "CreateTemplate"
        | "GetTemplate"
        | "ListTemplates"
        | "DeleteTemplate"
        | "UpdateTemplate"
        | "CreateConfigurationSet"
        | "DeleteConfigurationSet"
        | "DescribeConfigurationSet"
        | "ListConfigurationSets"
        | "CreateConfigurationSetEventDestination"
        | "UpdateConfigurationSetEventDestination"
        | "DeleteConfigurationSetEventDestination"
        | "GetSendQuota"
        | "GetSendStatistics"
        | "GetAccountSendingEnabled"
        | "CreateReceiptRuleSet"
        | "DeleteReceiptRuleSet"
        | "DescribeReceiptRuleSet"
        | "ListReceiptRuleSets"
        | "CloneReceiptRuleSet"
        | "SetActiveReceiptRuleSet"
        | "ReorderReceiptRuleSet"
        | "CreateReceiptRule"
        | "DeleteReceiptRule"
        | "DescribeReceiptRule"
        | "UpdateReceiptRule"
        | "CreateReceiptFilter"
        | "DeleteReceiptFilter"
        | "ListReceiptFilters" => "ses",
        // SNS subscription confirm/unsubscribe links can arrive unsigned.
        "ConfirmSubscription" | "Unsubscribe" => "sns",
        _ => return None,
    };
    Some(service.to_owned())
}
554
555fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
557 let auth = headers.get("authorization")?.to_str().ok()?;
558 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
559 Some(normalize_service_name(&info.service).to_string())
560}
561
/// Rewrites credential-scope service aliases to the name used for routing.
///
/// `bedrock-runtime` becomes `bedrock` and `apigatewayv2` becomes `apigateway`;
/// every other name, including the empty string, is returned unchanged.
fn normalize_service_name(service: &str) -> &str {
    // (alias, canonical name) pairs.
    const ALIASES: [(&str, &str); 2] = [
        ("bedrock-runtime", "bedrock"),
        ("apigatewayv2", "apigateway"),
    ];
    ALIASES
        .iter()
        .find(|(alias, _)| *alias == service)
        .map_or(service, |&(_, canonical)| canonical)
}
587
588pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
590 decode_form_urlencoded(body)
591}
592
593pub fn flatten_json_to_query(body: &Bytes) -> HashMap<String, String> {
610 let mut out = HashMap::new();
611 let Ok(value) = serde_json::from_slice::<serde_json::Value>(body) else {
612 return out;
613 };
614 if value.is_object() {
615 flatten_json_value("", &value, &mut out);
616 }
617 out
618}
619
620fn flatten_json_value(prefix: &str, value: &serde_json::Value, out: &mut HashMap<String, String>) {
621 match value {
622 serde_json::Value::Object(map) => {
623 for (k, v) in map {
624 let child = if prefix.is_empty() {
625 k.clone()
626 } else {
627 format!("{prefix}.{k}")
628 };
629 flatten_json_value(&child, v, out);
630 }
631 }
632 serde_json::Value::Array(items) => {
633 for (i, v) in items.iter().enumerate() {
634 let child = format!("{prefix}.member.{}", i + 1);
635 flatten_json_value(&child, v, out);
636 }
637 }
638 serde_json::Value::Null => {}
639 serde_json::Value::String(s) => {
640 out.insert(prefix.to_string(), s.clone());
641 }
642 serde_json::Value::Bool(b) => {
643 out.insert(prefix.to_string(), b.to_string());
644 }
645 serde_json::Value::Number(n) => {
646 out.insert(prefix.to_string(), n.to_string());
647 }
648 }
649}
650
651fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
652 let s = std::str::from_utf8(input).unwrap_or("");
653 let mut result = HashMap::new();
654 for pair in s.split('&') {
655 if pair.is_empty() {
656 continue;
657 }
658 let (key, value) = match pair.find('=') {
659 Some(pos) => (&pair[..pos], &pair[pos + 1..]),
660 None => (pair, ""),
661 };
662 result.insert(url_decode(key), url_decode(value));
663 }
664 result
665}
666
/// Percent-decodes one `application/x-www-form-urlencoded` component.
///
/// `+` becomes a space and `%XY` becomes the byte `0xXY`. The resulting bytes are
/// interpreted as UTF-8, so multi-byte escapes such as `%C3%A9` decode to `é` and
/// raw non-ASCII characters pass through intact. Invalid UTF-8 sequences are
/// replaced with U+FFFD. A malformed escape (`%` not followed by two hex digits) is
/// dropped together with the up-to-two bytes that followed it.
fn url_decode(input: &str) -> String {
    // Collect raw bytes first: pushing each decoded byte as a `char` would treat
    // it as Latin-1 and mangle multi-byte UTF-8 sequences.
    let mut decoded = Vec::with_capacity(input.len());
    let mut bytes = input.bytes();
    while let Some(b) = bytes.next() {
        match b {
            b'+' => decoded.push(b' '),
            b'%' => {
                let high = bytes.next().and_then(from_hex);
                let low = bytes.next().and_then(from_hex);
                if let (Some(h), Some(l)) = (high, low) {
                    decoded.push(h << 4 | l);
                }
            }
            _ => decoded.push(b),
        }
    }
    match String::from_utf8(decoded) {
        Ok(text) => text,
        Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
    }
}

/// Value of a single ASCII hex digit (either case), or `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
694
// Unit tests for request detection, host parsing, and form/JSON parameter decoding.
#[cfg(test)]
mod tests {
    use super::*;

    // --- X-Amz-Target parsing -------------------------------------------------

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_kinesis() {
        let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
        assert_eq!(result.service, "kinesis");
        assert_eq!(result.action, "ListStreams");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    // --- Form-encoded body parsing ---------------------------------------------

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }

    #[test]
    fn parse_query_body_empty_returns_empty_map() {
        let body = Bytes::from("");
        let params = parse_query_body(&body);
        assert!(params.is_empty());
    }

    #[test]
    fn parse_query_body_duplicate_keys_last_wins() {
        let body = Bytes::from("key=a&key=b");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "b");
    }

    #[test]
    fn parse_query_body_single_key() {
        let body = Bytes::from("key=value");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "value");
    }

    #[test]
    fn parse_amz_target_ecs() {
        let result = parse_amz_target("AmazonEC2ContainerServiceV20141113.ListClusters").unwrap();
        assert_eq!(result.service, "ecs");
        assert_eq!(result.action, "ListClusters");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_invalid_returns_none() {
        assert!(parse_amz_target("NoDotHere").is_none());
        assert!(parse_amz_target("").is_none());
    }

    #[test]
    fn parse_amz_target_cloudwatch_json() {
        let result = parse_amz_target("GraniteServiceVersion20100801.PutMetricData").unwrap();
        // The versioned CloudWatch target prefix routes to the `monitoring` service.
        assert_eq!(result.service, "monitoring");
        assert_eq!(result.action, "PutMetricData");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    // --- JSON → Query flattening -----------------------------------------------

    #[test]
    fn flatten_json_to_query_nested() {
        let body = Bytes::from(
            serde_json::json!({
                "Namespace": "MyApp",
                "MetricData": [{
                    "MetricName": "Latency",
                    "Value": 12.5,
                    "StatisticValues": {"SampleCount": 3, "Sum": 10},
                    "Dimensions": [{"Name": "Endpoint", "Value": "/api"}]
                }]
            })
            .to_string(),
        );
        let flat = flatten_json_to_query(&body);
        assert_eq!(flat.get("Namespace").unwrap(), "MyApp");
        assert_eq!(
            flat.get("MetricData.member.1.MetricName").unwrap(),
            "Latency"
        );
        assert_eq!(flat.get("MetricData.member.1.Value").unwrap(), "12.5");
        assert_eq!(
            flat.get("MetricData.member.1.StatisticValues.SampleCount")
                .unwrap(),
            "3"
        );
        assert_eq!(
            flat.get("MetricData.member.1.Dimensions.member.1.Name")
                .unwrap(),
            "Endpoint"
        );
        assert_eq!(
            flat.get("MetricData.member.1.Dimensions.member.1.Value")
                .unwrap(),
            "/api"
        );
    }

    #[test]
    fn flatten_json_to_query_non_object_is_empty() {
        assert!(flatten_json_to_query(&Bytes::from_static(b"[]")).is_empty());
        assert!(flatten_json_to_query(&Bytes::from_static(b"not json")).is_empty());
    }

    #[test]
    fn parse_amz_target_various_prefixes() {
        assert_eq!(
            parse_amz_target("AmazonSQS.SendMessage").unwrap().service,
            "sqs"
        );
        assert_eq!(
            parse_amz_target("AmazonSNS.Publish").unwrap().service,
            "sns"
        );
        assert_eq!(
            parse_amz_target("DynamoDB_20120810.GetItem")
                .unwrap()
                .service,
            "dynamodb"
        );
        assert_eq!(
            parse_amz_target("Logs_20140328.PutLogEvents")
                .unwrap()
                .service,
            "logs"
        );
        assert_eq!(
            parse_amz_target("secretsmanager.GetSecretValue")
                .unwrap()
                .service,
            "secretsmanager"
        );
        assert_eq!(
            parse_amz_target("TrentService.Encrypt").unwrap().service,
            "kms"
        );
        assert_eq!(
            parse_amz_target("AWSCognitoIdentityProviderService.InitiateAuth")
                .unwrap()
                .service,
            "cognito-idp"
        );
        assert_eq!(
            parse_amz_target("AWSStepFunctions.StartExecution")
                .unwrap()
                .service,
            "states"
        );
        assert_eq!(
            parse_amz_target("AWSOrganizationsV20161128.CreateOrganization")
                .unwrap()
                .service,
            "organizations"
        );
        assert!(parse_amz_target("UnknownServicePrefix.Action").is_none());
    }

    // --- Action-name inference -------------------------------------------------

    #[test]
    fn infer_service_from_action_maps_sts() {
        assert_eq!(
            infer_service_from_action("AssumeRole").as_deref(),
            Some("sts")
        );
        assert_eq!(
            infer_service_from_action("GetCallerIdentity").as_deref(),
            Some("sts")
        );
    }

    #[test]
    fn infer_service_from_action_maps_iam() {
        assert_eq!(
            infer_service_from_action("CreateUser").as_deref(),
            Some("iam")
        );
        assert_eq!(
            infer_service_from_action("ListRoles").as_deref(),
            Some("iam")
        );
    }

    #[test]
    fn infer_service_from_action_maps_ses() {
        assert_eq!(
            infer_service_from_action("SendEmail").as_deref(),
            Some("ses")
        );
        assert_eq!(
            infer_service_from_action("ListIdentities").as_deref(),
            Some("ses")
        );
    }

    #[test]
    fn infer_service_from_action_maps_sns_confirmation_flow() {
        // Subscription confirm/unsubscribe links may carry no SigV4 signature, so
        // the action name alone must identify SNS.
        assert_eq!(
            infer_service_from_action("ConfirmSubscription").as_deref(),
            Some("sns")
        );
        assert_eq!(
            infer_service_from_action("Unsubscribe").as_deref(),
            Some("sns")
        );
    }

    #[test]
    fn detect_service_routes_unsigned_confirm_subscription_to_sns() {
        let mut headers = HeaderMap::new();
        // No `authorization` header and a non-routing host: only the action name can
        // identify the service.
        headers.insert("host", "localhost:4566".parse().unwrap());
        let mut query_params = HashMap::new();
        query_params.insert("Action".to_string(), "ConfirmSubscription".to_string());
        query_params.insert(
            "TopicArn".to_string(),
            "arn:aws:sns:us-east-1:000000000000:t".to_string(),
        );
        query_params.insert("Token".to_string(), "abc123".to_string());

        let detected = detect_service(&headers, &query_params, &Bytes::new())
            .expect("ConfirmSubscription must route to a service");
        assert_eq!(detected.service, "sns");
        assert_eq!(detected.action, "ConfirmSubscription");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn infer_service_from_action_unknown_returns_none() {
        assert!(infer_service_from_action("NotARealAction").is_none());
    }

    #[test]
    fn rest_protocol_for_returns_none_for_non_rest_service() {
        assert!(rest_protocol_for("sqs").is_none());
    }

    // --- Percent decoding ------------------------------------------------------

    #[test]
    fn url_decode_handles_percent_and_plus() {
        assert_eq!(url_decode("hello+world"), "hello world");
        assert_eq!(url_decode("hello%20world"), "hello world");
        assert_eq!(url_decode("100%25"), "100%");
    }

    #[test]
    fn url_decode_ignores_malformed_percent() {
        assert_eq!(url_decode("%ZZ"), "");
    }

    #[test]
    fn from_hex_valid_digits() {
        assert_eq!(from_hex(b'0'), Some(0));
        assert_eq!(from_hex(b'9'), Some(9));
        assert_eq!(from_hex(b'a'), Some(10));
        assert_eq!(from_hex(b'F'), Some(15));
    }

    #[test]
    fn from_hex_invalid_returns_none() {
        assert!(from_hex(b'g').is_none());
        assert!(from_hex(b' ').is_none());
    }

    // --- End-to-end detection --------------------------------------------------

    #[test]
    fn detect_service_via_amz_target() {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", "AmazonSSM.GetParameter".parse().unwrap());
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ssm");
        assert_eq!(detected.action, "GetParameter");
    }

    #[test]
    fn detect_service_via_query_action_with_inferred_service() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "AssumeRole".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "AssumeRole");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_via_form_body() {
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::from("Action=SendEmail&Source=x%40y.com");
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ses");
        assert_eq!(detected.action, "SendEmail");
    }

    #[test]
    fn detect_service_via_sigv2_presigned() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("AWSAccessKeyId".to_string(), "AKID".to_string());
        query.insert("Signature".to_string(), "sig".to_string());
        query.insert("Expires".to_string(), "1234567890".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_via_sigv4_presigned_credential() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/s3/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_unknown_returns_none() {
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::new();
        assert!(detect_service(&headers, &query, &body).is_none());
    }

    // --- Service-name normalisation --------------------------------------------

    #[test]
    fn normalize_service_name_aliases_apigatewayv2_to_apigateway() {
        assert_eq!(normalize_service_name("apigatewayv2"), "apigateway");
    }

    #[test]
    fn normalize_service_name_aliases_bedrock_runtime_to_bedrock() {
        assert_eq!(normalize_service_name("bedrock-runtime"), "bedrock");
    }

    #[test]
    fn normalize_service_name_passes_through_unaliased_services() {
        assert_eq!(normalize_service_name("bedrock"), "bedrock");
        assert_eq!(normalize_service_name("s3"), "s3");
        assert_eq!(normalize_service_name("lambda"), "lambda");
        assert_eq!(normalize_service_name(""), "");
        assert_eq!(
            normalize_service_name("unknown-future-service"),
            "unknown-future-service"
        );
    }

    #[test]
    fn detect_service_via_authorization_header_normalizes_bedrock_runtime() {
        let mut headers = HeaderMap::new();
        // A SigV4 credential scoped to `bedrock-runtime` must be normalised before
        // the REST protocol lookup.
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 \
             Credential=AKID/20240101/us-east-1/bedrock-runtime/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn detect_service_via_sigv4_presigned_credential_normalizes_bedrock_runtime() {
        let headers = HeaderMap::new();
        // Same normalisation, but via a presigned-URL credential.
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/bedrock-runtime/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    // --- Host-based routing ----------------------------------------------------

    #[test]
    fn parse_routing_host_localstack_basic() {
        let h = parse_routing_host("sqs.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_localstack_with_port() {
        let h = parse_routing_host("lambda.eu-west-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "eu-west-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_case_insensitive() {
        let h = parse_routing_host("SQS.US-EAST-1.LOCALHOST.LOCALSTACK.CLOUD:4566").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");

        let h = parse_routing_host("LAMBDA.US-EAST-1.AMAZONAWS.COM").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "us-east-1");
    }

    #[test]
    fn parse_routing_host_localstack_s3_virtual_hosted() {
        let h =
            parse_routing_host("my-bucket.s3.us-east-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_localstack_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_service_region() {
        let h = parse_routing_host("sqs.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("dynamodb.eu-west-2.amazonaws.com:443").unwrap();
        assert_eq!(h.service, "dynamodb");
        assert_eq!(h.region, "eu-west-2");
    }

    #[test]
    fn parse_routing_host_aws_s3_path_style_modern() {
        let h = parse_routing_host("s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_aws_s3_virtual_hosted_modern() {
        let h = parse_routing_host("my-bucket.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_s3_legacy_global() {
        let h = parse_routing_host("s3.amazonaws.com").unwrap();
        // The legacy global endpoint has no region label; it is reported as us-east-1.
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_legacy_global_dotted_bucket() {
        let h = parse_routing_host("a.b.c.s3.amazonaws.com").unwrap();
        // All labels before `s3` form the bucket name.
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    #[test]
    fn parse_routing_host_aws_s3_dash_separated() {
        let h = parse_routing_host("s3-us-west-2.amazonaws.com").unwrap();
        // Legacy `s3-<region>` endpoint form.
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3-us-west-2.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_rejects_plain_localhost() {
        assert!(parse_routing_host("localhost:4566").is_none());
        assert!(parse_routing_host("127.0.0.1:4566").is_none());
    }

    #[test]
    fn parse_routing_host_rejects_unknown_suffix() {
        assert!(parse_routing_host("sqs.us-east-1.example.com").is_none());
        assert!(parse_routing_host("s3.us-east-1.aws").is_none());
    }

    #[test]
    fn parse_routing_host_empty_and_malformed_rejected() {
        assert!(parse_routing_host("").is_none());
        assert!(parse_routing_host(".localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("..localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("sqs.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("foo.bar.baz.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host(".amazonaws.com").is_none());
        assert!(parse_routing_host("amazonaws.com").is_none());
    }

    #[test]
    fn parse_routing_host_bare_s3_accesspoint_does_not_panic() {
        // NOTE(review): this host has no `.amazonaws.com` suffix, so it is rejected
        // before any prefix parser runs.
        assert!(parse_routing_host("s3-accesspoint").is_none());
    }

    #[test]
    fn detect_service_via_host_for_rest_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_via_host_for_rest_json_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "lambda");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn detect_service_via_host_plus_query_action() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "sqs.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "ListQueues".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sqs");
        assert_eq!(detected.action, "ListQueues");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_sigv4_wins_over_host() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 Credential=AKID/20240101/us-east-1/s3/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        // The signed credential scope takes priority over the routing host.
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_host_for_virtual_hosted_s3() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "my-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }
}