1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Wire protocol of an incoming AWS API request, as inferred by the
/// detection functions in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AwsProtocol {
    /// Query protocol: the operation is named by an `Action` parameter carried
    /// in the query string or in a form-urlencoded body.
    Query,
    /// JSON protocol: the operation is named by the `X-Amz-Target` header.
    Json,
    /// REST protocol assigned to the services listed in `REST_XML_SERVICES`
    /// (and to SigV2-presigned S3 URLs).
    Rest,
    /// REST protocol assigned to the services listed in `REST_JSON_SERVICES`.
    RestJson,
}
21
/// Services that `rest_protocol_for` maps to [`AwsProtocol::Rest`].
const REST_XML_SERVICES: &[&str] = &["s3", "cloudfront", "route53"];

/// Services that `rest_protocol_for` maps to [`AwsProtocol::RestJson`].
///
/// Entries are compared against already-normalized service names (see
/// `normalize_service_name`), which is why e.g. `bedrock-runtime` does not
/// appear here on its own.
const REST_JSON_SERVICES: &[&str] = &[
    // Compute / serverless
    "lambda",
    "batch",
    "scheduler",
    "pipes",
    // Messaging
    "ses",
    // API management
    "apigateway",
    // Bedrock family
    "bedrock",
    "bedrock-agent",
    "bedrock-agent-runtime",
    // Data / resources
    "rds-data",
    "dsql",
    "resource-groups",
];
40
41#[derive(Debug, Clone)]
43pub struct DetectedRequest {
44 pub service: String,
45 pub action: String,
46 pub protocol: AwsProtocol,
47}
48
49pub fn detect_service_headers_only(
56 headers: &HeaderMap,
57 query_params: &HashMap<String, String>,
58) -> Option<DetectedRequest> {
59 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
61 return parse_amz_target(target);
62 }
63 if let Some(action) = query_params.get("Action") {
64 let service = extract_service_from_auth(headers)
65 .or_else(|| infer_service_from_action(action))
66 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
67 if let Some(service) = service {
68 return Some(DetectedRequest {
69 service,
70 action: action.clone(),
71 protocol: AwsProtocol::Query,
72 });
73 }
74 }
75 if let Some(service) = extract_service_from_auth(headers) {
76 if let Some(protocol) = rest_protocol_for(&service) {
77 return Some(DetectedRequest {
78 service,
79 action: String::new(),
80 protocol,
81 });
82 }
83 }
84 if let Some(credential) = query_params.get("X-Amz-Credential") {
85 let parts: Vec<&str> = credential.split('/').collect();
86 if parts.len() >= 4 {
87 let service = normalize_service_name(parts[3]).to_string();
88 if let Some(protocol) = rest_protocol_for(&service) {
89 return Some(DetectedRequest {
90 service,
91 action: String::new(),
92 protocol,
93 });
94 }
95 }
96 }
97 if query_params.contains_key("AWSAccessKeyId")
98 && query_params.contains_key("Signature")
99 && query_params.contains_key("Expires")
100 {
101 return Some(DetectedRequest {
102 service: "s3".to_string(),
103 action: String::new(),
104 protocol: AwsProtocol::Rest,
105 });
106 }
107 if let Some(host_info) = parse_routing_host_from_headers(headers) {
108 if let Some(protocol) = rest_protocol_for(&host_info.service) {
109 return Some(DetectedRequest {
110 service: host_info.service,
111 action: String::new(),
112 protocol,
113 });
114 }
115 }
116 None
117}
118
119pub fn detect_service(
121 headers: &HeaderMap,
122 query_params: &HashMap<String, String>,
123 body: &Bytes,
124) -> Option<DetectedRequest> {
125 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
127 return parse_amz_target(target);
128 }
129
130 if let Some(action) = query_params.get("Action") {
132 let service = extract_service_from_auth(headers)
133 .or_else(|| infer_service_from_action(action))
134 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
135 if let Some(service) = service {
136 return Some(DetectedRequest {
137 service,
138 action: action.clone(),
139 protocol: AwsProtocol::Query,
140 });
141 }
142 }
143
144 {
146 let form_params = decode_form_urlencoded(body);
147
148 if let Some(action) = form_params.get("Action") {
149 let service = extract_service_from_auth(headers)
150 .or_else(|| infer_service_from_action(action))
151 .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
152 if let Some(service) = service {
153 return Some(DetectedRequest {
154 service,
155 action: action.clone(),
156 protocol: AwsProtocol::Query,
157 });
158 }
159 }
160 }
161
162 if let Some(service) = extract_service_from_auth(headers) {
164 if let Some(protocol) = rest_protocol_for(&service) {
165 return Some(DetectedRequest {
166 service,
167 action: String::new(), protocol,
169 });
170 }
171 }
172
173 if let Some(credential) = query_params.get("X-Amz-Credential") {
175 let parts: Vec<&str> = credential.split('/').collect();
177 if parts.len() >= 4 {
178 let service = normalize_service_name(parts[3]).to_string();
179 if let Some(protocol) = rest_protocol_for(&service) {
180 return Some(DetectedRequest {
181 service,
182 action: String::new(),
183 protocol,
184 });
185 }
186 }
187 }
188
189 if query_params.contains_key("AWSAccessKeyId")
193 && query_params.contains_key("Signature")
194 && query_params.contains_key("Expires")
195 {
196 return Some(DetectedRequest {
197 service: "s3".to_string(),
198 action: String::new(),
199 protocol: AwsProtocol::Rest,
200 });
201 }
202
203 if let Some(host_info) = parse_routing_host_from_headers(headers) {
207 if let Some(protocol) = rest_protocol_for(&host_info.service) {
208 return Some(DetectedRequest {
209 service: host_info.service,
210 action: String::new(),
211 protocol,
212 });
213 }
214 }
215
216 None
217}
218
/// Routing information extracted from a request's `Host` header.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RoutingHost {
    /// Service label from the host name (`"s3"` for every S3 host shape).
    pub service: String,
    /// Region label from the host name (`"us-east-1"` for legacy global S3).
    pub region: String,
    /// For S3 virtual-hosted / access-point hosts, the label(s) in front of the
    /// S3 marker joined with `.`; `None` when the host names no bucket.
    pub bucket: Option<String>,
}
234
/// Lower-case suffix of LocalStack-style hosts, e.g.
/// `sqs.us-east-1.localhost.localstack.cloud`.
const LOCALSTACK_SUFFIX: &str = ".localhost.localstack.cloud";
/// Lower-case suffix of AWS endpoint hosts, e.g. `sqs.us-east-1.amazonaws.com`.
const AWS_SUFFIX: &str = ".amazonaws.com";
237
238pub fn parse_routing_host(host: &str) -> Option<RoutingHost> {
242 let hostname = host.split(':').next()?;
243 if hostname.is_empty() {
244 return None;
245 }
246 let hostname = hostname.to_ascii_lowercase();
247 if let Some(prefix) = hostname.strip_suffix(LOCALSTACK_SUFFIX) {
248 return parse_localstack_prefix(prefix);
249 }
250 if hostname == "amazonaws.com" {
251 return None;
252 }
253 if let Some(prefix) = hostname.strip_suffix(AWS_SUFFIX) {
254 return parse_aws_prefix(prefix);
255 }
256 None
257}
258
259pub fn parse_routing_host_from_headers(headers: &HeaderMap) -> Option<RoutingHost> {
261 let host = headers.get("host")?.to_str().ok()?;
262 parse_routing_host(host)
263}
264
265fn parse_localstack_prefix(prefix: &str) -> Option<RoutingHost> {
266 if prefix.is_empty() {
267 return None;
268 }
269 let labels: Vec<&str> = prefix.split('.').collect();
270 if labels.iter().any(|l| l.is_empty()) {
271 return None;
272 }
273 match labels.len() {
274 2 => Some(RoutingHost {
275 service: labels[0].to_string(),
276 region: labels[1].to_string(),
277 bucket: None,
278 }),
279 n if n >= 3 && labels[n - 2] == "s3" => {
280 let bucket = labels[..n - 2].join(".");
281 Some(RoutingHost {
282 service: "s3".to_string(),
283 region: labels[n - 1].to_string(),
284 bucket: Some(bucket),
285 })
286 }
287 n if n >= 3 && labels[n - 2] == "s3-accesspoint" => {
288 let bucket = labels[..n - 2].join(".");
289 Some(RoutingHost {
290 service: "s3".to_string(),
291 region: labels[n - 1].to_string(),
292 bucket: Some(bucket),
293 })
294 }
295 n if n >= 3 && labels[n - 2] == "s3-control" => Some(RoutingHost {
296 service: "s3".to_string(),
297 region: labels[n - 1].to_string(),
298 bucket: None,
299 }),
300 _ => None,
301 }
302}
303
304fn parse_aws_prefix(prefix: &str) -> Option<RoutingHost> {
316 if prefix.is_empty() {
317 return None;
318 }
319 let labels: Vec<&str> = prefix.split('.').collect();
320 if labels.iter().any(|l| l.is_empty()) {
321 return None;
322 }
323 let last = *labels.last()?;
324
325 if let Some(region) = last.strip_prefix("s3-") {
328 if !region.is_empty() {
329 let bucket = if labels.len() >= 2 {
330 Some(labels[..labels.len() - 1].join("."))
331 } else {
332 None
333 };
334 return Some(RoutingHost {
335 service: "s3".to_string(),
336 region: region.to_string(),
337 bucket,
338 });
339 }
340 }
341
342 if last == "s3" {
346 if labels.len() == 1 {
347 return Some(RoutingHost {
348 service: "s3".to_string(),
349 region: "us-east-1".to_string(),
350 bucket: None,
351 });
352 }
353 return Some(RoutingHost {
354 service: "s3".to_string(),
355 region: "us-east-1".to_string(),
356 bucket: Some(labels[..labels.len() - 1].join(".")),
357 });
358 }
359
360 if last == "s3-accesspoint" {
363 if labels.len() == 2 {
364 return Some(RoutingHost {
365 service: "s3".to_string(),
366 region: labels[0].to_string(),
367 bucket: None,
368 });
369 }
370 if labels.len() >= 3 {
374 let bucket = labels[..labels.len() - 2].join(".");
375 return Some(RoutingHost {
376 service: "s3".to_string(),
377 region: labels[labels.len() - 1].to_string(),
378 bucket: Some(bucket),
379 });
380 }
381 }
382
383 if labels.len() >= 2 && labels[labels.len() - 2] == "s3-control" {
386 return Some(RoutingHost {
387 service: "s3".to_string(),
388 region: last.to_string(),
389 bucket: None,
390 });
391 }
392
393 match labels.len() {
394 2 => Some(RoutingHost {
397 service: labels[0].to_string(),
398 region: labels[1].to_string(),
399 bucket: None,
400 }),
401 n if n >= 3 && labels[n - 2] == "s3" => {
403 let bucket = labels[..n - 2].join(".");
404 Some(RoutingHost {
405 service: "s3".to_string(),
406 region: labels[n - 1].to_string(),
407 bucket: Some(bucket),
408 })
409 }
410 _ => None,
411 }
412}
413
414fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
417 let (prefix, action) = target.rsplit_once('.')?;
418
419 let service = match prefix {
420 "AWSEvents" => "events",
421 "AmazonSSM" => "ssm",
422 "AmazonSQS" => "sqs",
423 "AmazonSNS" => "sns",
424 "DynamoDB_20120810" => "dynamodb",
425 "DynamoDBStreams_20120810" => "dynamodbstreams",
426 "Logs_20140328" => "logs",
427 s if s.starts_with("secretsmanager") => "secretsmanager",
428 s if s.starts_with("TrentService") => "kms",
429 s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
430 s if s.starts_with("AWSCognitoIdentityService") => "cognito-identity",
431 s if s.starts_with("Kinesis_20131202") => "kinesis",
432 s if s.starts_with("AmazonEC2ContainerRegistry_V") => "ecr",
433 s if s.starts_with("AmazonEC2ContainerServiceV") => "ecs",
434 s if s.starts_with("AWSStepFunctions") => "states",
435 s if s.starts_with("AWSOrganizationsV") => "organizations",
436 "CertificateManager" => "acm",
437 "AnyScaleFrontendService" => "application-autoscaling",
438 "AWSWAF_20190729" => "wafv2",
441 "AmazonAthena" => "athena",
442 s if s.starts_with("Firehose_") => "firehose",
443 "AWSGlue" => "glue",
444 "CloudApiService" => "cloudcontrolapi",
445 s if s.starts_with("GraniteServiceVersion") => "monitoring",
451 _ => return None,
452 };
453
454 Some(DetectedRequest {
455 service: service.to_string(),
456 action: action.to_string(),
457 protocol: AwsProtocol::Json,
458 })
459}
460
461fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
463 if REST_XML_SERVICES.contains(&service) {
464 Some(AwsProtocol::Rest)
465 } else if REST_JSON_SERVICES.contains(&service) {
466 Some(AwsProtocol::RestJson)
467 } else {
468 None
469 }
470}
471
/// Guesses which service owns a Query-protocol `Action` name.
///
/// Used by detection as a fallback when no SigV4 `Authorization` header
/// supplies the service. Returns `None` for action names not listed here.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    const SES: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];
    // SNS subscription confirmation / unsubscribe links.
    const SNS: &[&str] = &["ConfirmSubscription", "Unsubscribe"];

    // The tables are disjoint, so lookup order does not affect the result.
    let tables: [(&[&str], &str); 4] = [(STS, "sts"), (IAM, "iam"), (SES, "ses"), (SNS, "sns")];
    tables
        .iter()
        .find(|(actions, _)| actions.iter().any(|&known| known == action))
        .map(|(_, service)| service.to_string())
}
543
544fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
546 let auth = headers.get("authorization")?.to_str().ok()?;
547 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
548 Some(normalize_service_name(&info.service).to_string())
549}
550
/// Maps alternative service names to the canonical name used for routing:
/// `bedrock-runtime` → `bedrock`, `apigatewayv2` → `apigateway`. Every other
/// name is returned unchanged.
fn normalize_service_name(service: &str) -> &str {
    if service == "bedrock-runtime" {
        "bedrock"
    } else if service == "apigatewayv2" {
        "apigateway"
    } else {
        service
    }
}
576
577pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
579 decode_form_urlencoded(body)
580}
581
582pub fn flatten_json_to_query(body: &Bytes) -> HashMap<String, String> {
599 let mut out = HashMap::new();
600 let Ok(value) = serde_json::from_slice::<serde_json::Value>(body) else {
601 return out;
602 };
603 if value.is_object() {
604 flatten_json_value("", &value, &mut out);
605 }
606 out
607}
608
609fn flatten_json_value(prefix: &str, value: &serde_json::Value, out: &mut HashMap<String, String>) {
610 match value {
611 serde_json::Value::Object(map) => {
612 for (k, v) in map {
613 let child = if prefix.is_empty() {
614 k.clone()
615 } else {
616 format!("{prefix}.{k}")
617 };
618 flatten_json_value(&child, v, out);
619 }
620 }
621 serde_json::Value::Array(items) => {
622 for (i, v) in items.iter().enumerate() {
623 let child = format!("{prefix}.member.{}", i + 1);
624 flatten_json_value(&child, v, out);
625 }
626 }
627 serde_json::Value::Null => {}
628 serde_json::Value::String(s) => {
629 out.insert(prefix.to_string(), s.clone());
630 }
631 serde_json::Value::Bool(b) => {
632 out.insert(prefix.to_string(), b.to_string());
633 }
634 serde_json::Value::Number(n) => {
635 out.insert(prefix.to_string(), n.to_string());
636 }
637 }
638}
639
/// Decodes an `application/x-www-form-urlencoded` payload into a map.
///
/// Pairs are separated by `&`; each is split at its first `=` (a pair without
/// `=` maps to an empty value). Empty pairs are skipped and for duplicate keys
/// the last occurrence wins. Input that is not valid UTF-8 yields an empty map.
fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
    let text = std::str::from_utf8(input).unwrap_or("");
    text.split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (url_decode(key), url_decode(value))
        })
        .collect()
}

/// Decodes a single form-urlencoded component.
///
/// `+` becomes a space and `%XX` becomes the byte `0xXX`. The resulting byte
/// sequence is interpreted as UTF-8, so multi-byte characters such as
/// `%C3%A9` decode to `é` (invalid sequences become U+FFFD). Decoding bytes
/// first matters because pushing each byte as a `char` would reinterpret them
/// as Latin-1 and corrupt any non-ASCII text.
///
/// A `%` not followed by two hex digits is dropped together with the (up to)
/// two bytes after it, e.g. `"%ZZ"` decodes to `""`.
fn url_decode(input: &str) -> String {
    let mut decoded: Vec<u8> = Vec::with_capacity(input.len());
    let mut bytes = input.bytes();
    while let Some(b) = bytes.next() {
        match b {
            b'+' => decoded.push(b' '),
            b'%' => {
                // Both bytes are always consumed, even if the first is invalid.
                let high = bytes.next().and_then(from_hex);
                let low = bytes.next().and_then(from_hex);
                if let (Some(h), Some(l)) = (high, low) {
                    decoded.push(h << 4 | l);
                }
            }
            _ => decoded.push(b),
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}

/// Value of one ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None`.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
683
// Unit tests for request detection, host routing, and form/JSON parameter
// decoding in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // X-Amz-Target parsing: service mapping, action extraction, JSON protocol.
    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_kinesis() {
        let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
        assert_eq!(result.service, "kinesis");
        assert_eq!(result.action, "ListStreams");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    // Form-urlencoded body decoding (percent escapes, empty input, duplicates).
    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }

    #[test]
    fn parse_query_body_empty_returns_empty_map() {
        let body = Bytes::from("");
        let params = parse_query_body(&body);
        assert!(params.is_empty());
    }

    #[test]
    fn parse_query_body_duplicate_keys_last_wins() {
        let body = Bytes::from("key=a&key=b");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "b");
    }

    #[test]
    fn parse_query_body_single_key() {
        let body = Bytes::from("key=value");
        let params = parse_query_body(&body);
        assert_eq!(params.get("key").unwrap(), "value");
    }

    // Versioned prefix family: matched with `starts_with`.
    #[test]
    fn parse_amz_target_ecs() {
        let result = parse_amz_target("AmazonEC2ContainerServiceV20141113.ListClusters").unwrap();
        assert_eq!(result.service, "ecs");
        assert_eq!(result.action, "ListClusters");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    // No `.` separator means no prefix/action split is possible.
    #[test]
    fn parse_amz_target_invalid_returns_none() {
        assert!(parse_amz_target("NoDotHere").is_none());
        assert!(parse_amz_target("").is_none());
    }

    #[test]
    fn parse_amz_target_cloudwatch_json() {
        let result = parse_amz_target("GraniteServiceVersion20100801.PutMetricData").unwrap();
        assert_eq!(result.service, "monitoring");
        assert_eq!(result.action, "PutMetricData");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    // JSON → Query flattening: nested objects use dotted keys, arrays use
    // 1-based `.member.<n>` segments.
    #[test]
    fn flatten_json_to_query_nested() {
        let body = Bytes::from(
            serde_json::json!({
                "Namespace": "MyApp",
                "MetricData": [{
                    "MetricName": "Latency",
                    "Value": 12.5,
                    "StatisticValues": {"SampleCount": 3, "Sum": 10},
                    "Dimensions": [{"Name": "Endpoint", "Value": "/api"}]
                }]
            })
            .to_string(),
        );
        let flat = flatten_json_to_query(&body);
        assert_eq!(flat.get("Namespace").unwrap(), "MyApp");
        assert_eq!(
            flat.get("MetricData.member.1.MetricName").unwrap(),
            "Latency"
        );
        assert_eq!(flat.get("MetricData.member.1.Value").unwrap(), "12.5");
        assert_eq!(
            flat.get("MetricData.member.1.StatisticValues.SampleCount")
                .unwrap(),
            "3"
        );
        assert_eq!(
            flat.get("MetricData.member.1.Dimensions.member.1.Name")
                .unwrap(),
            "Endpoint"
        );
        assert_eq!(
            flat.get("MetricData.member.1.Dimensions.member.1.Value")
                .unwrap(),
            "/api"
        );
    }

    // Only a top-level JSON object is flattened; arrays and invalid JSON yield nothing.
    #[test]
    fn flatten_json_to_query_non_object_is_empty() {
        assert!(flatten_json_to_query(&Bytes::from_static(b"[]")).is_empty());
        assert!(flatten_json_to_query(&Bytes::from_static(b"not json")).is_empty());
    }

    #[test]
    fn parse_amz_target_various_prefixes() {
        assert_eq!(
            parse_amz_target("AmazonSQS.SendMessage").unwrap().service,
            "sqs"
        );
        assert_eq!(
            parse_amz_target("AmazonSNS.Publish").unwrap().service,
            "sns"
        );
        assert_eq!(
            parse_amz_target("DynamoDB_20120810.GetItem")
                .unwrap()
                .service,
            "dynamodb"
        );
        assert_eq!(
            parse_amz_target("Logs_20140328.PutLogEvents")
                .unwrap()
                .service,
            "logs"
        );
        assert_eq!(
            parse_amz_target("secretsmanager.GetSecretValue")
                .unwrap()
                .service,
            "secretsmanager"
        );
        assert_eq!(
            parse_amz_target("TrentService.Encrypt").unwrap().service,
            "kms"
        );
        assert_eq!(
            parse_amz_target("AWSCognitoIdentityProviderService.InitiateAuth")
                .unwrap()
                .service,
            "cognito-idp"
        );
        assert_eq!(
            parse_amz_target("AWSStepFunctions.StartExecution")
                .unwrap()
                .service,
            "states"
        );
        assert_eq!(
            parse_amz_target("AWSOrganizationsV20161128.CreateOrganization")
                .unwrap()
                .service,
            "organizations"
        );
        assert!(parse_amz_target("UnknownServicePrefix.Action").is_none());
    }

    // Action-name → service fallback used when no SigV4 credential scope is available.
    #[test]
    fn infer_service_from_action_maps_sts() {
        assert_eq!(
            infer_service_from_action("AssumeRole").as_deref(),
            Some("sts")
        );
        assert_eq!(
            infer_service_from_action("GetCallerIdentity").as_deref(),
            Some("sts")
        );
    }

    #[test]
    fn infer_service_from_action_maps_iam() {
        assert_eq!(
            infer_service_from_action("CreateUser").as_deref(),
            Some("iam")
        );
        assert_eq!(
            infer_service_from_action("ListRoles").as_deref(),
            Some("iam")
        );
    }

    #[test]
    fn infer_service_from_action_maps_ses() {
        assert_eq!(
            infer_service_from_action("SendEmail").as_deref(),
            Some("ses")
        );
        assert_eq!(
            infer_service_from_action("ListIdentities").as_deref(),
            Some("ses")
        );
    }

    #[test]
    fn infer_service_from_action_maps_sns_confirmation_flow() {
        assert_eq!(
            infer_service_from_action("ConfirmSubscription").as_deref(),
            Some("sns")
        );
        assert_eq!(
            infer_service_from_action("Unsubscribe").as_deref(),
            Some("sns")
        );
    }

    // No Authorization header and a plain `localhost` host: only the action
    // name can identify the service here.
    #[test]
    fn detect_service_routes_unsigned_confirm_subscription_to_sns() {
        let mut headers = HeaderMap::new();
        headers.insert("host", "localhost:4566".parse().unwrap());
        let mut query_params = HashMap::new();
        query_params.insert("Action".to_string(), "ConfirmSubscription".to_string());
        query_params.insert(
            "TopicArn".to_string(),
            "arn:aws:sns:us-east-1:000000000000:t".to_string(),
        );
        query_params.insert("Token".to_string(), "abc123".to_string());

        let detected = detect_service(&headers, &query_params, &Bytes::new())
            .expect("ConfirmSubscription must route to a service");
        assert_eq!(detected.service, "sns");
        assert_eq!(detected.action, "ConfirmSubscription");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn infer_service_from_action_unknown_returns_none() {
        assert!(infer_service_from_action("NotARealAction").is_none());
    }

    #[test]
    fn rest_protocol_for_returns_none_for_non_rest_service() {
        assert!(rest_protocol_for("sqs").is_none());
    }

    // Percent/plus decoding of individual form components.
    #[test]
    fn url_decode_handles_percent_and_plus() {
        assert_eq!(url_decode("hello+world"), "hello world");
        assert_eq!(url_decode("hello%20world"), "hello world");
        assert_eq!(url_decode("100%25"), "100%");
    }

    // A malformed escape is dropped together with the two bytes after `%`.
    #[test]
    fn url_decode_ignores_malformed_percent() {
        assert_eq!(url_decode("%ZZ"), "");
    }

    #[test]
    fn from_hex_valid_digits() {
        assert_eq!(from_hex(b'0'), Some(0));
        assert_eq!(from_hex(b'9'), Some(9));
        assert_eq!(from_hex(b'a'), Some(10));
        assert_eq!(from_hex(b'F'), Some(15));
    }

    #[test]
    fn from_hex_invalid_returns_none() {
        assert!(from_hex(b'g').is_none());
        assert!(from_hex(b' ').is_none());
    }

    // End-to-end detect_service paths, one per detection source.
    #[test]
    fn detect_service_via_amz_target() {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", "AmazonSSM.GetParameter".parse().unwrap());
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ssm");
        assert_eq!(detected.action, "GetParameter");
    }

    #[test]
    fn detect_service_via_query_action_with_inferred_service() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "AssumeRole".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "AssumeRole");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_via_form_body() {
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::from("Action=SendEmail&Source=x%40y.com");
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "ses");
        assert_eq!(detected.action, "SendEmail");
    }

    // AWSAccessKeyId + Signature + Expires in the query string → S3.
    #[test]
    fn detect_service_via_sigv2_presigned() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert("AWSAccessKeyId".to_string(), "AKID".to_string());
        query.insert("Signature".to_string(), "sig".to_string());
        query.insert("Expires".to_string(), "1234567890".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    // Service is the 4th `/`-separated field of X-Amz-Credential.
    #[test]
    fn detect_service_via_sigv4_presigned_credential() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/s3/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_unknown_returns_none() {
        let headers = HeaderMap::new();
        let query = HashMap::new();
        let body = Bytes::new();
        assert!(detect_service(&headers, &query, &body).is_none());
    }

    // Service-name aliasing.
    #[test]
    fn normalize_service_name_aliases_apigatewayv2_to_apigateway() {
        assert_eq!(normalize_service_name("apigatewayv2"), "apigateway");
    }

    #[test]
    fn normalize_service_name_aliases_bedrock_runtime_to_bedrock() {
        assert_eq!(normalize_service_name("bedrock-runtime"), "bedrock");
    }

    #[test]
    fn normalize_service_name_passes_through_unaliased_services() {
        assert_eq!(normalize_service_name("bedrock"), "bedrock");
        assert_eq!(normalize_service_name("s3"), "s3");
        assert_eq!(normalize_service_name("lambda"), "lambda");
        assert_eq!(normalize_service_name(""), "");
        assert_eq!(
            normalize_service_name("unknown-future-service"),
            "unknown-future-service"
        );
    }

    // Aliases must also be applied to the Authorization-header and presigned
    // credential paths, otherwise `bedrock-runtime` would have no REST protocol.
    #[test]
    fn detect_service_via_authorization_header_normalizes_bedrock_runtime() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 \
             Credential=AKID/20240101/us-east-1/bedrock-runtime/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    #[test]
    fn detect_service_via_sigv4_presigned_credential_normalizes_bedrock_runtime() {
        let headers = HeaderMap::new();
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/bedrock-runtime/aws4_request".to_string(),
        );
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "bedrock");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    // Host-header routing: LocalStack-style hosts.
    #[test]
    fn parse_routing_host_localstack_basic() {
        let h = parse_routing_host("sqs.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_localstack_with_port() {
        let h = parse_routing_host("lambda.eu-west-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "eu-west-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_case_insensitive() {
        let h = parse_routing_host("SQS.US-EAST-1.LOCALHOST.LOCALSTACK.CLOUD:4566").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");

        let h = parse_routing_host("LAMBDA.US-EAST-1.AMAZONAWS.COM").unwrap();
        assert_eq!(h.service, "lambda");
        assert_eq!(h.region, "us-east-1");
    }

    #[test]
    fn parse_routing_host_localstack_s3_virtual_hosted() {
        let h =
            parse_routing_host("my-bucket.s3.us-east-1.localhost.localstack.cloud:4566").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_localstack_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.localhost.localstack.cloud").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    // Host-header routing: real AWS endpoint shapes.
    #[test]
    fn parse_routing_host_aws_service_region() {
        let h = parse_routing_host("sqs.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "sqs");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("dynamodb.eu-west-2.amazonaws.com:443").unwrap();
        assert_eq!(h.service, "dynamodb");
        assert_eq!(h.region, "eu-west-2");
    }

    #[test]
    fn parse_routing_host_aws_s3_path_style_modern() {
        let h = parse_routing_host("s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());
    }

    #[test]
    fn parse_routing_host_aws_s3_virtual_hosted_modern() {
        let h = parse_routing_host("my-bucket.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_vhost_bucket_with_dots() {
        let h = parse_routing_host("a.b.c.s3.us-east-1.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    // Legacy global S3 endpoint carries no region; us-east-1 is reported.
    #[test]
    fn parse_routing_host_aws_s3_legacy_global() {
        let h = parse_routing_host("s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    #[test]
    fn parse_routing_host_aws_s3_legacy_global_dotted_bucket() {
        let h = parse_routing_host("a.b.c.s3.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-east-1");
        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
    }

    // Legacy `s3-<region>` dash-separated endpoint.
    #[test]
    fn parse_routing_host_aws_s3_dash_separated() {
        let h = parse_routing_host("s3-us-west-2.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert!(h.bucket.is_none());

        let h = parse_routing_host("my-bucket.s3-us-west-2.amazonaws.com").unwrap();
        assert_eq!(h.service, "s3");
        assert_eq!(h.region, "us-west-2");
        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
    }

    // Hosts without a recognised suffix are not routable.
    #[test]
    fn parse_routing_host_rejects_plain_localhost() {
        assert!(parse_routing_host("localhost:4566").is_none());
        assert!(parse_routing_host("127.0.0.1:4566").is_none());
    }

    #[test]
    fn parse_routing_host_rejects_unknown_suffix() {
        assert!(parse_routing_host("sqs.us-east-1.example.com").is_none());
        assert!(parse_routing_host("s3.us-east-1.aws").is_none());
    }

    #[test]
    fn parse_routing_host_empty_and_malformed_rejected() {
        assert!(parse_routing_host("").is_none());
        assert!(parse_routing_host(".localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("..localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("sqs.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host("foo.bar.baz.localhost.localstack.cloud").is_none());
        assert!(parse_routing_host(".amazonaws.com").is_none());
        assert!(parse_routing_host("amazonaws.com").is_none());
    }

    #[test]
    fn parse_routing_host_bare_s3_accesspoint_does_not_panic() {
        assert!(parse_routing_host("s3-accesspoint").is_none());
    }

    // Host header as the last-resort detection source for REST services.
    #[test]
    fn detect_service_via_host_for_rest_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_via_host_for_rest_json_service() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "lambda");
        assert_eq!(detected.protocol, AwsProtocol::RestJson);
    }

    // `ListQueues` is not in the action table, so the host supplies the service.
    #[test]
    fn detect_service_via_host_plus_query_action() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "sqs.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "ListQueues".to_string());
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "sqs");
        assert_eq!(detected.action, "ListQueues");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    // The SigV4 credential scope takes precedence over a conflicting Host header.
    #[test]
    fn detect_service_sigv4_wins_over_host() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 Credential=AKID/20240101/us-east-1/s3/aws4_request, \
             SignedHeaders=host, Signature=abc"
                .parse()
                .unwrap(),
        );
        headers.insert(
            "host",
            "lambda.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_host_for_virtual_hosted_s3() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "host",
            "my-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        let query = HashMap::new();
        let body = Bytes::new();
        let detected = detect_service(&headers, &query, &body).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }
}