Skip to main content

fakecloud_core/
protocol.rs

1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// The wire protocol used by an AWS service.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AwsProtocol {
    /// Query protocol: form-encoded body (or query string), `Action` param,
    /// XML response.
    /// Used by e.g.: SQS, SNS, IAM, STS, SES v1.
    Query,
    /// JSON protocol: JSON body, `X-Amz-Target` header, JSON response.
    /// Used by e.g.: SSM, EventBridge, DynamoDB, Secrets Manager, KMS,
    /// CloudWatch Logs, Kinesis, ECS, ECR, Step Functions (see
    /// `parse_amz_target` for the full prefix table).
    Json,
    /// REST protocol: HTTP method + path-based routing, XML responses.
    /// Used by the credential-scope services `s3`, `cloudfront`, `route53`.
    Rest,
    /// REST-JSON protocol: HTTP method + path-based routing, JSON responses.
    /// Used by the credential-scope services `lambda`, `ses` (SES v2),
    /// `apigateway`, `bedrock`, `scheduler`.
    RestJson,
}
21
/// SigV4 service names whose APIs are REST with XML payloads. Looked up by
/// `rest_protocol_for` for services identified by credential scope,
/// presigned URL, or `Host` header.
const REST_XML_SERVICES: &[&str] = &["s3", "cloudfront", "route53"];

/// SigV4 service names whose APIs are REST with JSON payloads.
const REST_JSON_SERVICES: &[&str] = &[
    "lambda",
    // SES v2. SES v1 Query actions carry an `Action` parameter and are
    // matched before the REST fallback.
    "ses",
    "apigateway",
    "bedrock",
    "scheduler",
];
27
28/// Detected service name and action from an incoming HTTP request.
29#[derive(Debug, Clone)]
30pub struct DetectedRequest {
31    pub service: String,
32    pub action: String,
33    pub protocol: AwsProtocol,
34}
35
36/// Header-only service detection. Skips the form-encoded body sniff so
37/// the dispatch path can decide whether to stream or buffer the body
38/// without first reading it. Returns `None` when only a body sniff
39/// would succeed; the caller must then fall back to [`detect_service`]
40/// after buffering. Used to opt streaming routes (S3 PutObject /
41/// UploadPart, ECR OCI v2 blob upload) out of the global body cap.
42pub fn detect_service_headers_only(
43    headers: &HeaderMap,
44    query_params: &HashMap<String, String>,
45) -> Option<DetectedRequest> {
46    // Mirrors `detect_service` minus step 3 (form-body sniff).
47    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
48        return parse_amz_target(target);
49    }
50    if let Some(action) = query_params.get("Action") {
51        let service = extract_service_from_auth(headers)
52            .or_else(|| infer_service_from_action(action))
53            .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
54        if let Some(service) = service {
55            return Some(DetectedRequest {
56                service,
57                action: action.clone(),
58                protocol: AwsProtocol::Query,
59            });
60        }
61    }
62    if let Some(service) = extract_service_from_auth(headers) {
63        if let Some(protocol) = rest_protocol_for(&service) {
64            return Some(DetectedRequest {
65                service,
66                action: String::new(),
67                protocol,
68            });
69        }
70    }
71    if let Some(credential) = query_params.get("X-Amz-Credential") {
72        let parts: Vec<&str> = credential.split('/').collect();
73        if parts.len() >= 4 {
74            let service = parts[3].to_string();
75            if let Some(protocol) = rest_protocol_for(&service) {
76                return Some(DetectedRequest {
77                    service,
78                    action: String::new(),
79                    protocol,
80                });
81            }
82        }
83    }
84    if query_params.contains_key("AWSAccessKeyId")
85        && query_params.contains_key("Signature")
86        && query_params.contains_key("Expires")
87    {
88        return Some(DetectedRequest {
89            service: "s3".to_string(),
90            action: String::new(),
91            protocol: AwsProtocol::Rest,
92        });
93    }
94    if let Some(host_info) = parse_routing_host_from_headers(headers) {
95        if let Some(protocol) = rest_protocol_for(&host_info.service) {
96            return Some(DetectedRequest {
97                service: host_info.service,
98                action: String::new(),
99                protocol,
100            });
101        }
102    }
103    None
104}
105
106/// Detect the target service and action from HTTP request components.
107pub fn detect_service(
108    headers: &HeaderMap,
109    query_params: &HashMap<String, String>,
110    body: &Bytes,
111) -> Option<DetectedRequest> {
112    // 1. Check X-Amz-Target header (JSON protocol)
113    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
114        return parse_amz_target(target);
115    }
116
117    // 2. Check for Query protocol (Action parameter in query string or form body)
118    if let Some(action) = query_params.get("Action") {
119        let service = extract_service_from_auth(headers)
120            .or_else(|| infer_service_from_action(action))
121            .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
122        if let Some(service) = service {
123            return Some(DetectedRequest {
124                service,
125                action: action.clone(),
126                protocol: AwsProtocol::Query,
127            });
128        }
129    }
130
131    // 3. Try form-encoded body
132    {
133        let form_params = decode_form_urlencoded(body);
134
135        if let Some(action) = form_params.get("Action") {
136            let service = extract_service_from_auth(headers)
137                .or_else(|| infer_service_from_action(action))
138                .or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
139            if let Some(service) = service {
140                return Some(DetectedRequest {
141                    service,
142                    action: action.clone(),
143                    protocol: AwsProtocol::Query,
144                });
145            }
146        }
147    }
148
149    // 4. Fallback: check auth header for REST-style services (S3, Lambda, SES, etc.)
150    if let Some(service) = extract_service_from_auth(headers) {
151        if let Some(protocol) = rest_protocol_for(&service) {
152            return Some(DetectedRequest {
153                service,
154                action: String::new(), // REST services determine action from method+path
155                protocol,
156            });
157        }
158    }
159
160    // 5. Check query params for presigned URL auth (X-Amz-Credential for SigV4)
161    if let Some(credential) = query_params.get("X-Amz-Credential") {
162        // Format: AKID/date/region/service/aws4_request
163        let parts: Vec<&str> = credential.split('/').collect();
164        if parts.len() >= 4 {
165            let service = parts[3].to_string();
166            if let Some(protocol) = rest_protocol_for(&service) {
167                return Some(DetectedRequest {
168                    service,
169                    action: String::new(),
170                    protocol,
171                });
172            }
173        }
174    }
175
176    // 6. Check for SigV2-style presigned URL (AWSAccessKeyId + Signature + Expires)
177    //    Only match when all three SigV2 presigned-URL parameters are present so
178    //    we don't accidentally claim non-S3 requests.
179    if query_params.contains_key("AWSAccessKeyId")
180        && query_params.contains_key("Signature")
181        && query_params.contains_key("Expires")
182    {
183        return Some(DetectedRequest {
184            service: "s3".to_string(),
185            action: String::new(),
186            protocol: AwsProtocol::Rest,
187        });
188    }
189
190    // 7. Fallback: unsigned REST-style request carrying a LocalStack-shaped
191    //    Host header. Lets fixtures and curl-style probes reach the right
192    //    service without SigV4; signed requests were already handled in step 4.
193    if let Some(host_info) = parse_routing_host_from_headers(headers) {
194        if let Some(protocol) = rest_protocol_for(&host_info.service) {
195            return Some(DetectedRequest {
196                service: host_info.service,
197                action: String::new(),
198                protocol,
199            });
200        }
201    }
202
203    None
204}
205
/// Routing information decoded from a `Host` header: the target service,
/// its region, and (for virtual-hosted-style S3 only) the bucket.
///
/// Recognised shapes:
/// - LocalStack: `<service>.<region>.localhost.localstack.cloud[:port]` and
///   `<bucket>.s3.<region>.localhost.localstack.cloud[:port]`.
/// - AWS: `<service>.<region>.amazonaws.com`, plus path-style and
///   virtual-hosted-style S3. The S3 forms include the legacy region-less
///   `s3.amazonaws.com` / `<bucket>.s3.amazonaws.com` and the older
///   dash-separated `s3-<region>.amazonaws.com`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RoutingHost {
    /// Service label, e.g. `sqs` or `s3`.
    pub service: String,
    /// Region label; `us-east-1` for the legacy region-less S3 endpoints.
    pub region: String,
    /// Set only for virtual-hosted-style S3 hostnames.
    pub bucket: Option<String>,
}
221
/// Host suffix for LocalStack-style hostnames, matched after lower-casing.
const LOCALSTACK_SUFFIX: &str = ".localhost.localstack.cloud";
/// Host suffix for real AWS endpoints. The leading dot means the bare apex
/// `amazonaws.com` never matches.
const AWS_SUFFIX: &str = ".amazonaws.com";
224
225/// Parse a `Host` header value for a LocalStack- or AWS-shaped hostname.
226/// Returns `None` for anything that doesn't match — callers fall through
227/// to their existing detection path.
228pub fn parse_routing_host(host: &str) -> Option<RoutingHost> {
229    let hostname = host.split(':').next()?;
230    if hostname.is_empty() {
231        return None;
232    }
233    let hostname = hostname.to_ascii_lowercase();
234    if let Some(prefix) = hostname.strip_suffix(LOCALSTACK_SUFFIX) {
235        return parse_localstack_prefix(prefix);
236    }
237    if hostname == "amazonaws.com" {
238        return None;
239    }
240    if let Some(prefix) = hostname.strip_suffix(AWS_SUFFIX) {
241        return parse_aws_prefix(prefix);
242    }
243    None
244}
245
246/// Pull the `Host` header and parse it with [`parse_routing_host`].
247pub fn parse_routing_host_from_headers(headers: &HeaderMap) -> Option<RoutingHost> {
248    let host = headers.get("host")?.to_str().ok()?;
249    parse_routing_host(host)
250}
251
252fn parse_localstack_prefix(prefix: &str) -> Option<RoutingHost> {
253    if prefix.is_empty() {
254        return None;
255    }
256    let labels: Vec<&str> = prefix.split('.').collect();
257    if labels.iter().any(|l| l.is_empty()) {
258        return None;
259    }
260    match labels.len() {
261        2 => Some(RoutingHost {
262            service: labels[0].to_string(),
263            region: labels[1].to_string(),
264            bucket: None,
265        }),
266        n if n >= 3 && labels[n - 2] == "s3" => {
267            let bucket = labels[..n - 2].join(".");
268            Some(RoutingHost {
269                service: "s3".to_string(),
270                region: labels[n - 1].to_string(),
271                bucket: Some(bucket),
272            })
273        }
274        _ => None,
275    }
276}
277
278/// Parse the prefix before `.amazonaws.com`.
279///
280/// Handles every variant AWS has shipped for the common REST/Query services:
281///
282/// - `<service>.<region>` — modern regional endpoint (most services).
283/// - `s3.<region>` — modern path-style S3.
284/// - `<bucket>.s3.<region>` — modern virtual-hosted S3 (bucket may contain dots).
285/// - `s3` — legacy S3 global endpoint (implicitly `us-east-1`).
286/// - `<bucket>.s3` — legacy virtual-hosted S3 (implicitly `us-east-1`).
287/// - `s3-<region>` — older dash-separated path-style S3.
288/// - `<bucket>.s3-<region>` — older dash-separated virtual-hosted S3.
289fn parse_aws_prefix(prefix: &str) -> Option<RoutingHost> {
290    if prefix.is_empty() {
291        return None;
292    }
293    let labels: Vec<&str> = prefix.split('.').collect();
294    if labels.iter().any(|l| l.is_empty()) {
295        return None;
296    }
297    let last = *labels.last()?;
298
299    // `s3-<region>` as the last label: dash-separated S3. Bucket, if any,
300    // is whatever precedes it.
301    if let Some(region) = last.strip_prefix("s3-") {
302        if !region.is_empty() {
303            let bucket = if labels.len() >= 2 {
304                Some(labels[..labels.len() - 1].join("."))
305            } else {
306                None
307            };
308            return Some(RoutingHost {
309                service: "s3".to_string(),
310                region: region.to_string(),
311                bucket,
312            });
313        }
314    }
315
316    // Legacy global S3: last label is `s3`, no region present. `s3` on its
317    // own is the path-style global endpoint; anything preceding it is the
318    // bucket (including dotted names like `a.b.s3.amazonaws.com`).
319    if last == "s3" {
320        if labels.len() == 1 {
321            return Some(RoutingHost {
322                service: "s3".to_string(),
323                region: "us-east-1".to_string(),
324                bucket: None,
325            });
326        }
327        return Some(RoutingHost {
328            service: "s3".to_string(),
329            region: "us-east-1".to_string(),
330            bucket: Some(labels[..labels.len() - 1].join(".")),
331        });
332    }
333
334    match labels.len() {
335        // `<service>.<region>` — the common case. Covers `s3.<region>`
336        // path-style S3 too, since the service label falls through here.
337        2 => Some(RoutingHost {
338            service: labels[0].to_string(),
339            region: labels[1].to_string(),
340            bucket: None,
341        }),
342        // `<bucket>.s3.<region>` — modern virtual-hosted S3.
343        n if n >= 3 && labels[n - 2] == "s3" => {
344            let bucket = labels[..n - 2].join(".");
345            Some(RoutingHost {
346                service: "s3".to_string(),
347                region: labels[n - 1].to_string(),
348                bucket: Some(bucket),
349            })
350        }
351        _ => None,
352    }
353}
354
355/// Parse `X-Amz-Target: AWSEvents.PutEvents` -> service=events, action=PutEvents
356/// Parse `X-Amz-Target: AmazonSSM.GetParameter` -> service=ssm, action=GetParameter
357fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
358    let (prefix, action) = target.rsplit_once('.')?;
359
360    let service = match prefix {
361        "AWSEvents" => "events",
362        "AmazonSSM" => "ssm",
363        "AmazonSQS" => "sqs",
364        "AmazonSNS" => "sns",
365        "DynamoDB_20120810" => "dynamodb",
366        "Logs_20140328" => "logs",
367        s if s.starts_with("secretsmanager") => "secretsmanager",
368        s if s.starts_with("TrentService") => "kms",
369        s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
370        s if s.starts_with("Kinesis_20131202") => "kinesis",
371        s if s.starts_with("AmazonEC2ContainerRegistry_V") => "ecr",
372        s if s.starts_with("AmazonEC2ContainerServiceV") => "ecs",
373        s if s.starts_with("AWSStepFunctions") => "states",
374        s if s.starts_with("AWSOrganizationsV") => "organizations",
375        "CertificateManager" => "acm",
376        "AnyScaleFrontendService" => "application-autoscaling",
377        // Match the WAFv2 target version exactly so legacy WAF Classic
378        // (`AWSWAF_*` without the `_20190729` suffix) doesn't get routed here.
379        "AWSWAF_20190729" => "wafv2",
380        "AmazonAthena" => "athena",
381        _ => return None,
382    };
383
384    Some(DetectedRequest {
385        service: service.to_string(),
386        action: action.to_string(),
387        protocol: AwsProtocol::Json,
388    })
389}
390
391/// Returns the REST protocol variant for a service, or None if not a REST service.
392fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
393    if REST_XML_SERVICES.contains(&service) {
394        Some(AwsProtocol::Rest)
395    } else if REST_JSON_SERVICES.contains(&service) {
396        Some(AwsProtocol::RestJson)
397    } else {
398        None
399    }
400}
401
/// Infer service from the action name when no SigV4 auth is present.
///
/// Some AWS operations (e.g. `AssumeRoleWithSAML`, `AssumeRoleWithWebIdentity`)
/// do not require authentication and won't have an Authorization header.
/// Matching is exact and case-sensitive; unknown actions yield `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    // SES v1 (Query protocol)
    const SES_V1_ACTIONS: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];

    // The three action lists are disjoint, so lookup order does not matter.
    [
        ("sts", STS_ACTIONS),
        ("iam", IAM_ACTIONS),
        ("ses", SES_V1_ACTIONS),
    ]
    .iter()
    .find(|(_, actions)| actions.contains(&action))
    .map(|(service, _)| service.to_string())
}
469
470/// Extract service name from the SigV4 Authorization header credential scope.
471fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
472    let auth = headers.get("authorization")?.to_str().ok()?;
473    let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
474    Some(info.service)
475}
476
477/// Parse form-encoded body into key-value pairs.
478pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
479    decode_form_urlencoded(body)
480}
481
482fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
483    let s = std::str::from_utf8(input).unwrap_or("");
484    let mut result = HashMap::new();
485    for pair in s.split('&') {
486        if pair.is_empty() {
487            continue;
488        }
489        let (key, value) = match pair.find('=') {
490            Some(pos) => (&pair[..pos], &pair[pos + 1..]),
491            None => (pair, ""),
492        };
493        result.insert(url_decode(key), url_decode(value));
494    }
495    result
496}
497
/// Decode one `application/x-www-form-urlencoded` component.
///
/// `+` becomes a space and `%XX` becomes the raw byte `0xXX`. The decoded
/// bytes are then interpreted as UTF-8, so multi-byte sequences such as
/// `%C3%A9` (and literal non-ASCII input such as `é`) round-trip correctly.
/// Byte sequences that are not valid UTF-8 are replaced with U+FFFD.
///
/// A malformed escape (`%` not followed by two hex digits) is dropped
/// together with the up-to-two bytes after it, so `%ZZ` decodes to `""`.
fn url_decode(input: &str) -> String {
    let mut decoded: Vec<u8> = Vec::with_capacity(input.len());
    let mut bytes = input.bytes();
    while let Some(b) = bytes.next() {
        match b {
            b'+' => decoded.push(b' '),
            b'%' => {
                // Always consume both digit positions, even if the first
                // is invalid, so malformed escapes are skipped as a unit.
                let high = bytes.next().and_then(from_hex);
                let low = bytes.next().and_then(from_hex);
                if let (Some(h), Some(l)) = (high, low) {
                    decoded.push(h << 4 | l);
                }
            }
            // Raw bytes (including UTF-8 continuation bytes of non-ASCII
            // input) are kept as bytes, not widened to Latin-1 chars.
            other => decoded.push(other),
        }
    }
    // Fast path reuses the buffer; only invalid UTF-8 pays for a lossy copy.
    match String::from_utf8(decoded) {
        Ok(text) => text,
        Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
    }
}

/// Value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None`.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
525
526#[cfg(test)]
527mod tests {
528    use super::*;
529
530    #[test]
531    fn parse_amz_target_events() {
532        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
533        assert_eq!(result.service, "events");
534        assert_eq!(result.action, "PutEvents");
535        assert_eq!(result.protocol, AwsProtocol::Json);
536    }
537
538    #[test]
539    fn parse_amz_target_ssm() {
540        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
541        assert_eq!(result.service, "ssm");
542        assert_eq!(result.action, "GetParameter");
543    }
544
545    #[test]
546    fn parse_amz_target_kinesis() {
547        let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
548        assert_eq!(result.service, "kinesis");
549        assert_eq!(result.action, "ListStreams");
550        assert_eq!(result.protocol, AwsProtocol::Json);
551    }
552
553    #[test]
554    fn parse_query_body_basic() {
555        let body = Bytes::from(
556            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
557        );
558        let params = parse_query_body(&body);
559        assert_eq!(params.get("Action").unwrap(), "SendMessage");
560        assert_eq!(params.get("MessageBody").unwrap(), "hello");
561    }
562
563    #[test]
564    fn parse_query_body_empty_returns_empty_map() {
565        let body = Bytes::from("");
566        let params = parse_query_body(&body);
567        assert!(params.is_empty());
568    }
569
570    #[test]
571    fn parse_query_body_duplicate_keys_last_wins() {
572        let body = Bytes::from("key=a&key=b");
573        let params = parse_query_body(&body);
574        assert_eq!(params.get("key").unwrap(), "b");
575    }
576
577    #[test]
578    fn parse_query_body_single_key() {
579        let body = Bytes::from("key=value");
580        let params = parse_query_body(&body);
581        assert_eq!(params.get("key").unwrap(), "value");
582    }
583
584    #[test]
585    fn parse_amz_target_ecs() {
586        let result = parse_amz_target("AmazonEC2ContainerServiceV20141113.ListClusters").unwrap();
587        assert_eq!(result.service, "ecs");
588        assert_eq!(result.action, "ListClusters");
589        assert_eq!(result.protocol, AwsProtocol::Json);
590    }
591
592    #[test]
593    fn parse_amz_target_invalid_returns_none() {
594        assert!(parse_amz_target("NoDotHere").is_none());
595        assert!(parse_amz_target("").is_none());
596    }
597
598    #[test]
599    fn parse_amz_target_various_prefixes() {
600        assert_eq!(
601            parse_amz_target("AmazonSQS.SendMessage").unwrap().service,
602            "sqs"
603        );
604        assert_eq!(
605            parse_amz_target("AmazonSNS.Publish").unwrap().service,
606            "sns"
607        );
608        assert_eq!(
609            parse_amz_target("DynamoDB_20120810.GetItem")
610                .unwrap()
611                .service,
612            "dynamodb"
613        );
614        assert_eq!(
615            parse_amz_target("Logs_20140328.PutLogEvents")
616                .unwrap()
617                .service,
618            "logs"
619        );
620        assert_eq!(
621            parse_amz_target("secretsmanager.GetSecretValue")
622                .unwrap()
623                .service,
624            "secretsmanager"
625        );
626        assert_eq!(
627            parse_amz_target("TrentService.Encrypt").unwrap().service,
628            "kms"
629        );
630        assert_eq!(
631            parse_amz_target("AWSCognitoIdentityProviderService.InitiateAuth")
632                .unwrap()
633                .service,
634            "cognito-idp"
635        );
636        assert_eq!(
637            parse_amz_target("AWSStepFunctions.StartExecution")
638                .unwrap()
639                .service,
640            "states"
641        );
642        assert_eq!(
643            parse_amz_target("AWSOrganizationsV20161128.CreateOrganization")
644                .unwrap()
645                .service,
646            "organizations"
647        );
648        assert!(parse_amz_target("UnknownServicePrefix.Action").is_none());
649    }
650
651    #[test]
652    fn infer_service_from_action_maps_sts() {
653        assert_eq!(
654            infer_service_from_action("AssumeRole").as_deref(),
655            Some("sts")
656        );
657        assert_eq!(
658            infer_service_from_action("GetCallerIdentity").as_deref(),
659            Some("sts")
660        );
661    }
662
663    #[test]
664    fn infer_service_from_action_maps_iam() {
665        assert_eq!(
666            infer_service_from_action("CreateUser").as_deref(),
667            Some("iam")
668        );
669        assert_eq!(
670            infer_service_from_action("ListRoles").as_deref(),
671            Some("iam")
672        );
673    }
674
675    #[test]
676    fn infer_service_from_action_maps_ses() {
677        assert_eq!(
678            infer_service_from_action("SendEmail").as_deref(),
679            Some("ses")
680        );
681        assert_eq!(
682            infer_service_from_action("ListIdentities").as_deref(),
683            Some("ses")
684        );
685    }
686
687    #[test]
688    fn infer_service_from_action_unknown_returns_none() {
689        assert!(infer_service_from_action("NotARealAction").is_none());
690    }
691
692    #[test]
693    fn rest_protocol_for_returns_none_for_non_rest_service() {
694        assert!(rest_protocol_for("sqs").is_none());
695    }
696
697    #[test]
698    fn url_decode_handles_percent_and_plus() {
699        assert_eq!(url_decode("hello+world"), "hello world");
700        assert_eq!(url_decode("hello%20world"), "hello world");
701        assert_eq!(url_decode("100%25"), "100%");
702    }
703
704    #[test]
705    fn url_decode_ignores_malformed_percent() {
706        assert_eq!(url_decode("%ZZ"), "");
707    }
708
709    #[test]
710    fn from_hex_valid_digits() {
711        assert_eq!(from_hex(b'0'), Some(0));
712        assert_eq!(from_hex(b'9'), Some(9));
713        assert_eq!(from_hex(b'a'), Some(10));
714        assert_eq!(from_hex(b'F'), Some(15));
715    }
716
717    #[test]
718    fn from_hex_invalid_returns_none() {
719        assert!(from_hex(b'g').is_none());
720        assert!(from_hex(b' ').is_none());
721    }
722
723    #[test]
724    fn detect_service_via_amz_target() {
725        let mut headers = HeaderMap::new();
726        headers.insert("x-amz-target", "AmazonSSM.GetParameter".parse().unwrap());
727        let query = HashMap::new();
728        let body = Bytes::new();
729        let detected = detect_service(&headers, &query, &body).unwrap();
730        assert_eq!(detected.service, "ssm");
731        assert_eq!(detected.action, "GetParameter");
732    }
733
734    #[test]
735    fn detect_service_via_query_action_with_inferred_service() {
736        let headers = HeaderMap::new();
737        let mut query = HashMap::new();
738        query.insert("Action".to_string(), "AssumeRole".to_string());
739        let body = Bytes::new();
740        let detected = detect_service(&headers, &query, &body).unwrap();
741        assert_eq!(detected.service, "sts");
742        assert_eq!(detected.action, "AssumeRole");
743        assert_eq!(detected.protocol, AwsProtocol::Query);
744    }
745
746    #[test]
747    fn detect_service_via_form_body() {
748        let headers = HeaderMap::new();
749        let query = HashMap::new();
750        let body = Bytes::from("Action=SendEmail&Source=x%40y.com");
751        let detected = detect_service(&headers, &query, &body).unwrap();
752        assert_eq!(detected.service, "ses");
753        assert_eq!(detected.action, "SendEmail");
754    }
755
756    #[test]
757    fn detect_service_via_sigv2_presigned() {
758        let headers = HeaderMap::new();
759        let mut query = HashMap::new();
760        query.insert("AWSAccessKeyId".to_string(), "AKID".to_string());
761        query.insert("Signature".to_string(), "sig".to_string());
762        query.insert("Expires".to_string(), "1234567890".to_string());
763        let body = Bytes::new();
764        let detected = detect_service(&headers, &query, &body).unwrap();
765        assert_eq!(detected.service, "s3");
766        assert_eq!(detected.protocol, AwsProtocol::Rest);
767    }
768
769    #[test]
770    fn detect_service_via_sigv4_presigned_credential() {
771        let headers = HeaderMap::new();
772        let mut query = HashMap::new();
773        query.insert(
774            "X-Amz-Credential".to_string(),
775            "AKID/20240101/us-east-1/s3/aws4_request".to_string(),
776        );
777        let body = Bytes::new();
778        let detected = detect_service(&headers, &query, &body).unwrap();
779        assert_eq!(detected.service, "s3");
780        assert_eq!(detected.protocol, AwsProtocol::Rest);
781    }
782
783    #[test]
784    fn detect_service_unknown_returns_none() {
785        let headers = HeaderMap::new();
786        let query = HashMap::new();
787        let body = Bytes::new();
788        assert!(detect_service(&headers, &query, &body).is_none());
789    }
790
791    #[test]
792    fn parse_routing_host_localstack_basic() {
793        let h = parse_routing_host("sqs.us-east-1.localhost.localstack.cloud").unwrap();
794        assert_eq!(h.service, "sqs");
795        assert_eq!(h.region, "us-east-1");
796        assert!(h.bucket.is_none());
797    }
798
799    #[test]
800    fn parse_routing_host_localstack_with_port() {
801        let h = parse_routing_host("lambda.eu-west-1.localhost.localstack.cloud:4566").unwrap();
802        assert_eq!(h.service, "lambda");
803        assert_eq!(h.region, "eu-west-1");
804        assert!(h.bucket.is_none());
805    }
806
807    #[test]
808    fn parse_routing_host_case_insensitive() {
809        let h = parse_routing_host("SQS.US-EAST-1.LOCALHOST.LOCALSTACK.CLOUD:4566").unwrap();
810        assert_eq!(h.service, "sqs");
811        assert_eq!(h.region, "us-east-1");
812
813        let h = parse_routing_host("LAMBDA.US-EAST-1.AMAZONAWS.COM").unwrap();
814        assert_eq!(h.service, "lambda");
815        assert_eq!(h.region, "us-east-1");
816    }
817
818    #[test]
819    fn parse_routing_host_localstack_s3_virtual_hosted() {
820        let h =
821            parse_routing_host("my-bucket.s3.us-east-1.localhost.localstack.cloud:4566").unwrap();
822        assert_eq!(h.service, "s3");
823        assert_eq!(h.region, "us-east-1");
824        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
825    }
826
827    #[test]
828    fn parse_routing_host_localstack_s3_vhost_bucket_with_dots() {
829        let h = parse_routing_host("a.b.c.s3.us-east-1.localhost.localstack.cloud").unwrap();
830        assert_eq!(h.service, "s3");
831        assert_eq!(h.region, "us-east-1");
832        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
833    }
834
835    #[test]
836    fn parse_routing_host_aws_service_region() {
837        let h = parse_routing_host("sqs.us-east-1.amazonaws.com").unwrap();
838        assert_eq!(h.service, "sqs");
839        assert_eq!(h.region, "us-east-1");
840        assert!(h.bucket.is_none());
841
842        let h = parse_routing_host("dynamodb.eu-west-2.amazonaws.com:443").unwrap();
843        assert_eq!(h.service, "dynamodb");
844        assert_eq!(h.region, "eu-west-2");
845    }
846
847    #[test]
848    fn parse_routing_host_aws_s3_path_style_modern() {
849        let h = parse_routing_host("s3.us-east-1.amazonaws.com").unwrap();
850        assert_eq!(h.service, "s3");
851        assert_eq!(h.region, "us-east-1");
852        assert!(h.bucket.is_none());
853    }
854
855    #[test]
856    fn parse_routing_host_aws_s3_virtual_hosted_modern() {
857        let h = parse_routing_host("my-bucket.s3.us-east-1.amazonaws.com").unwrap();
858        assert_eq!(h.service, "s3");
859        assert_eq!(h.region, "us-east-1");
860        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
861    }
862
863    #[test]
864    fn parse_routing_host_aws_s3_vhost_bucket_with_dots() {
865        let h = parse_routing_host("a.b.c.s3.us-east-1.amazonaws.com").unwrap();
866        assert_eq!(h.service, "s3");
867        assert_eq!(h.region, "us-east-1");
868        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
869    }
870
871    #[test]
872    fn parse_routing_host_aws_s3_legacy_global() {
873        // `s3.amazonaws.com` (no region) is the legacy S3 global endpoint —
874        // AWS treats it as us-east-1 for both path-style and virtual-hosted.
875        let h = parse_routing_host("s3.amazonaws.com").unwrap();
876        assert_eq!(h.service, "s3");
877        assert_eq!(h.region, "us-east-1");
878        assert!(h.bucket.is_none());
879
880        let h = parse_routing_host("my-bucket.s3.amazonaws.com").unwrap();
881        assert_eq!(h.service, "s3");
882        assert_eq!(h.region, "us-east-1");
883        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
884    }
885
886    #[test]
887    fn parse_routing_host_aws_s3_legacy_global_dotted_bucket() {
888        // AWS allows buckets with dots (e.g. `a.b.c`) and still serves them
889        // via the legacy `<bucket>.s3.amazonaws.com` global endpoint.
890        let h = parse_routing_host("a.b.c.s3.amazonaws.com").unwrap();
891        assert_eq!(h.service, "s3");
892        assert_eq!(h.region, "us-east-1");
893        assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
894    }
895
896    #[test]
897    fn parse_routing_host_aws_s3_dash_separated() {
898        // Older dash-separated form still served by AWS.
899        let h = parse_routing_host("s3-us-west-2.amazonaws.com").unwrap();
900        assert_eq!(h.service, "s3");
901        assert_eq!(h.region, "us-west-2");
902        assert!(h.bucket.is_none());
903
904        let h = parse_routing_host("my-bucket.s3-us-west-2.amazonaws.com").unwrap();
905        assert_eq!(h.service, "s3");
906        assert_eq!(h.region, "us-west-2");
907        assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
908    }
909
910    #[test]
911    fn parse_routing_host_rejects_plain_localhost() {
912        assert!(parse_routing_host("localhost:4566").is_none());
913        assert!(parse_routing_host("127.0.0.1:4566").is_none());
914    }
915
916    #[test]
917    fn parse_routing_host_rejects_unknown_suffix() {
918        assert!(parse_routing_host("sqs.us-east-1.example.com").is_none());
919        assert!(parse_routing_host("s3.us-east-1.aws").is_none());
920    }
921
922    #[test]
923    fn parse_routing_host_empty_and_malformed_rejected() {
924        assert!(parse_routing_host("").is_none());
925        assert!(parse_routing_host(".localhost.localstack.cloud").is_none());
926        assert!(parse_routing_host("..localhost.localstack.cloud").is_none());
927        assert!(parse_routing_host("sqs.localhost.localstack.cloud").is_none());
928        assert!(parse_routing_host("foo.bar.baz.localhost.localstack.cloud").is_none());
929        assert!(parse_routing_host(".amazonaws.com").is_none());
930        assert!(parse_routing_host("amazonaws.com").is_none());
931    }
932
933    #[test]
934    fn detect_service_via_host_for_rest_service() {
935        let mut headers = HeaderMap::new();
936        headers.insert(
937            "host",
938            "s3.us-east-1.localhost.localstack.cloud:4566"
939                .parse()
940                .unwrap(),
941        );
942        let query = HashMap::new();
943        let body = Bytes::new();
944        let detected = detect_service(&headers, &query, &body).unwrap();
945        assert_eq!(detected.service, "s3");
946        assert_eq!(detected.protocol, AwsProtocol::Rest);
947    }
948
949    #[test]
950    fn detect_service_via_host_for_rest_json_service() {
951        let mut headers = HeaderMap::new();
952        headers.insert(
953            "host",
954            "lambda.us-east-1.localhost.localstack.cloud:4566"
955                .parse()
956                .unwrap(),
957        );
958        let query = HashMap::new();
959        let body = Bytes::new();
960        let detected = detect_service(&headers, &query, &body).unwrap();
961        assert_eq!(detected.service, "lambda");
962        assert_eq!(detected.protocol, AwsProtocol::RestJson);
963    }
964
965    #[test]
966    fn detect_service_via_host_plus_query_action() {
967        let mut headers = HeaderMap::new();
968        headers.insert(
969            "host",
970            "sqs.us-east-1.localhost.localstack.cloud:4566"
971                .parse()
972                .unwrap(),
973        );
974        let mut query = HashMap::new();
975        query.insert("Action".to_string(), "ListQueues".to_string());
976        let body = Bytes::new();
977        let detected = detect_service(&headers, &query, &body).unwrap();
978        assert_eq!(detected.service, "sqs");
979        assert_eq!(detected.action, "ListQueues");
980        assert_eq!(detected.protocol, AwsProtocol::Query);
981    }
982
983    #[test]
984    fn detect_service_sigv4_wins_over_host() {
985        let mut headers = HeaderMap::new();
986        headers.insert(
987            "authorization",
988            "AWS4-HMAC-SHA256 Credential=AKID/20240101/us-east-1/s3/aws4_request, \
989             SignedHeaders=host, Signature=abc"
990                .parse()
991                .unwrap(),
992        );
993        headers.insert(
994            "host",
995            "lambda.us-east-1.localhost.localstack.cloud:4566"
996                .parse()
997                .unwrap(),
998        );
999        let query = HashMap::new();
1000        let body = Bytes::new();
1001        let detected = detect_service(&headers, &query, &body).unwrap();
1002        // SigV4 credential scope says s3; Host header says lambda. SigV4 wins.
1003        assert_eq!(detected.service, "s3");
1004        assert_eq!(detected.protocol, AwsProtocol::Rest);
1005    }
1006
1007    #[test]
1008    fn detect_service_host_for_virtual_hosted_s3() {
1009        let mut headers = HeaderMap::new();
1010        headers.insert(
1011            "host",
1012            "my-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
1013                .parse()
1014                .unwrap(),
1015        );
1016        let query = HashMap::new();
1017        let body = Bytes::new();
1018        let detected = detect_service(&headers, &query, &body).unwrap();
1019        assert_eq!(detected.service, "s3");
1020        assert_eq!(detected.protocol, AwsProtocol::Rest);
1021    }
1022}