use bytes::Bytes;
use http::HeaderMap;
use std::collections::HashMap;
/// Protocol family assigned to a request by the detection functions in this file.
///
/// The variant records which request feature identified the service; it is
/// stored in [`DetectedRequest::protocol`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AwsProtocol {
/// Assigned when an `Action` parameter is found in the query string or in a
/// form-encoded body.
Query,
/// Assigned when the request is identified through its `x-amz-target` header.
Json,
/// Assigned to services listed in `REST_XML_SERVICES`, and to requests that
/// carry the `AWSAccessKeyId`, `Signature` and `Expires` query parameters.
Rest,
/// Assigned to services listed in `REST_JSON_SERVICES`.
RestJson,
}
/// Service names for which `rest_protocol_for` returns `AwsProtocol::Rest`.
const REST_XML_SERVICES: &[&str] = &["s3", "cloudfront", "route53"];
/// Service names for which `rest_protocol_for` returns `AwsProtocol::RestJson`.
///
/// Names are compared after `normalize_service_name`, which is why aliases such
/// as `bedrock-runtime` and `apigatewayv2` do not need their own entries.
const REST_JSON_SERVICES: &[&str] = &[
"lambda",
"ses",
"apigateway",
"bedrock",
"bedrock-agent",
"bedrock-agent-runtime",
"scheduler",
];
/// Result of service detection: which service a request targets and how.
#[derive(Debug, Clone)]
pub struct DetectedRequest {
/// Service name, e.g. `"s3"`, `"sqs"` or `"ssm"`.
pub service: String,
/// Operation name taken from `x-amz-target` or the `Action` parameter.
/// Empty for REST-style detections, where this file does not derive an action.
pub action: String,
/// Protocol family chosen by the detection step that matched.
pub protocol: AwsProtocol,
}
/// Identifies the target service from headers and query parameters alone,
/// never looking at the request body.
///
/// Strategies are tried in this order and the first hit wins:
///
/// 1. `x-amz-target` header (returns whatever `parse_amz_target` yields, even `None`).
/// 2. `Action` query parameter, with the service taken from the `authorization`
///    header, then inferred from the action name, then from the `host` header.
/// 3. Service from the `authorization` header, if it is a REST service.
/// 4. Service from the `X-Amz-Credential` query parameter, if it is a REST service.
/// 5. `AWSAccessKeyId` + `Signature` + `Expires` query parameters, reported as S3.
/// 6. Service from the `host` header, if it is a REST service.
///
/// Returns `None` when no strategy matches.
pub fn detect_service_headers_only(
    headers: &HeaderMap,
    query_params: &HashMap<String, String>,
) -> Option<DetectedRequest> {
    /// Wraps `service` in an action-less request when it speaks a REST protocol.
    fn rest_request(service: String) -> Option<DetectedRequest> {
        let protocol = rest_protocol_for(&service)?;
        Some(DetectedRequest {
            service,
            action: String::new(),
            protocol,
        })
    }

    // 1. A readable `x-amz-target` header is authoritative: an unknown target
    //    ends detection with `None` instead of falling through.
    let amz_target = headers
        .get("x-amz-target")
        .and_then(|value| value.to_str().ok());
    if let Some(target) = amz_target {
        return parse_amz_target(target);
    }

    // 2. Query protocol: the action is known, only the service has to be found.
    if let Some(action) = query_params.get("Action") {
        let resolved = extract_service_from_auth(headers)
            .or_else(|| infer_service_from_action(action))
            .or_else(|| parse_routing_host_from_headers(headers).map(|host| host.service));
        if let Some(service) = resolved {
            return Some(DetectedRequest {
                service,
                action: action.clone(),
                protocol: AwsProtocol::Query,
            });
        }
    }

    // 3. Signed REST request.
    let from_auth = extract_service_from_auth(headers).and_then(rest_request);
    if from_auth.is_some() {
        return from_auth;
    }

    // 4. Presigned URL: the fourth `/`-separated field of the credential is the service.
    let from_credential = query_params
        .get("X-Amz-Credential")
        .and_then(|credential| credential.split('/').nth(3))
        .and_then(|scoped| rest_request(normalize_service_name(scoped).to_string()));
    if from_credential.is_some() {
        return from_credential;
    }

    // 5. These three parameters together are treated as an S3 request.
    let has_legacy_signature = ["AWSAccessKeyId", "Signature", "Expires"]
        .iter()
        .all(|key| query_params.contains_key(*key));
    if has_legacy_signature {
        return Some(DetectedRequest {
            service: "s3".to_string(),
            action: String::new(),
            protocol: AwsProtocol::Rest,
        });
    }

    // 6. Last resort: the service encoded in the `host` header.
    parse_routing_host_from_headers(headers).and_then(|host| rest_request(host.service))
}
pub fn detect_service(
headers: &HeaderMap,
query_params: &HashMap<String, String>,
body: &Bytes,
) -> Option<DetectedRequest> {
if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
return parse_amz_target(target);
}
if let Some(action) = query_params.get("Action") {
let service = extract_service_from_auth(headers)
.or_else(|| infer_service_from_action(action))
.or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
if let Some(service) = service {
return Some(DetectedRequest {
service,
action: action.clone(),
protocol: AwsProtocol::Query,
});
}
}
{
let form_params = decode_form_urlencoded(body);
if let Some(action) = form_params.get("Action") {
let service = extract_service_from_auth(headers)
.or_else(|| infer_service_from_action(action))
.or_else(|| parse_routing_host_from_headers(headers).map(|h| h.service));
if let Some(service) = service {
return Some(DetectedRequest {
service,
action: action.clone(),
protocol: AwsProtocol::Query,
});
}
}
}
if let Some(service) = extract_service_from_auth(headers) {
if let Some(protocol) = rest_protocol_for(&service) {
return Some(DetectedRequest {
service,
action: String::new(), protocol,
});
}
}
if let Some(credential) = query_params.get("X-Amz-Credential") {
let parts: Vec<&str> = credential.split('/').collect();
if parts.len() >= 4 {
let service = normalize_service_name(parts[3]).to_string();
if let Some(protocol) = rest_protocol_for(&service) {
return Some(DetectedRequest {
service,
action: String::new(),
protocol,
});
}
}
}
if query_params.contains_key("AWSAccessKeyId")
&& query_params.contains_key("Signature")
&& query_params.contains_key("Expires")
{
return Some(DetectedRequest {
service: "s3".to_string(),
action: String::new(),
protocol: AwsProtocol::Rest,
});
}
if let Some(host_info) = parse_routing_host_from_headers(headers) {
if let Some(protocol) = rest_protocol_for(&host_info.service) {
return Some(DetectedRequest {
service: host_info.service,
action: String::new(),
protocol,
});
}
}
None
}
/// Routing information extracted from a request hostname by `parse_routing_host`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoutingHost {
/// Service name taken from the hostname (always `"s3"` for the S3 host forms).
pub service: String,
/// Region label taken from the hostname; `"us-east-1"` for the region-less
/// `s3.amazonaws.com` form.
pub region: String,
/// Leading hostname labels for virtual-hosted S3 hosts (dots preserved),
/// `None` otherwise.
pub bucket: Option<String>,
}
/// Hostname suffix (leading dot included) handled by `parse_localstack_prefix`.
const LOCALSTACK_SUFFIX: &str = ".localhost.localstack.cloud";
/// Hostname suffix (leading dot included) handled by `parse_aws_prefix`.
const AWS_SUFFIX: &str = ".amazonaws.com";
pub fn parse_routing_host(host: &str) -> Option<RoutingHost> {
let hostname = host.split(':').next()?;
if hostname.is_empty() {
return None;
}
let hostname = hostname.to_ascii_lowercase();
if let Some(prefix) = hostname.strip_suffix(LOCALSTACK_SUFFIX) {
return parse_localstack_prefix(prefix);
}
if hostname == "amazonaws.com" {
return None;
}
if let Some(prefix) = hostname.strip_suffix(AWS_SUFFIX) {
return parse_aws_prefix(prefix);
}
None
}
pub fn parse_routing_host_from_headers(headers: &HeaderMap) -> Option<RoutingHost> {
let host = headers.get("host")?.to_str().ok()?;
parse_routing_host(host)
}
/// Parses the hostname part that precedes `.localhost.localstack.cloud`.
///
/// Accepted shapes (labels are dot-separated):
///
/// * `<service>.<region>` — generic service endpoint.
/// * `<bucket...>.s3.<region>` and `<name...>.s3-accesspoint.<region>` — S3 with
///   the leading labels (dots preserved) reported as the bucket.
/// * `<labels...>.s3-control.<region>` — S3 without a bucket.
///
/// Any other shape, or a prefix containing an empty label, yields `None`.
fn parse_localstack_prefix(prefix: &str) -> Option<RoutingHost> {
    let labels: Vec<&str> = prefix.split('.').collect();
    // Also covers the empty prefix, which splits into a single empty label.
    if labels.contains(&"") {
        return None;
    }

    match labels.as_slice() {
        // Exactly two labels are always read as service + region.
        [service, region] => Some(RoutingHost {
            service: service.to_string(),
            region: region.to_string(),
            bucket: None,
        }),
        // Three or more labels: the second-to-last one must be an S3 marker.
        [leading @ .., marker, region] => {
            let bucket = match *marker {
                "s3" | "s3-accesspoint" => Some(leading.join(".")),
                "s3-control" => None,
                _ => return None,
            };
            Some(RoutingHost {
                service: "s3".to_string(),
                region: region.to_string(),
                bucket,
            })
        }
        _ => None,
    }
}
/// Parses the hostname part that precedes `.amazonaws.com`.
///
/// Accepted shapes (labels are dot-separated; `<bucket>` may contain dots),
/// checked in this order:
///
/// * `[<bucket>.]s3-<region>` — dash-separated S3 endpoint.
/// * `[<bucket>.]s3` — region-less S3 endpoint, reported as `us-east-1`.
/// * `[<bucket>.]s3.<region>` — dotted S3 endpoint.
/// * `[<name>.]s3-accesspoint.<region>` — S3 access point; the leading labels
///   are reported as the bucket.
/// * `[<labels>.]s3-control.<region>` — S3 without a bucket.
/// * `<service>.<region>` — any other service.
///
/// Any other shape, or a prefix containing an empty label, yields `None`.
fn parse_aws_prefix(prefix: &str) -> Option<RoutingHost> {
    if prefix.is_empty() {
        return None;
    }
    let labels: Vec<&str> = prefix.split('.').collect();
    if labels.iter().any(|label| label.is_empty()) {
        return None;
    }
    let count = labels.len();
    let last = labels[count - 1];

    let s3_host = |region: &str, bucket: Option<String>| RoutingHost {
        service: "s3".to_string(),
        region: region.to_string(),
        bucket,
    };
    // Joins the first `len` labels back into a bucket name; `None` when there are none.
    let leading = |len: usize| {
        if len == 0 {
            None
        } else {
            Some(labels[..len].join("."))
        }
    };

    // Dash-separated form: the region is glued onto the `s3-` label.
    if let Some(region) = last.strip_prefix("s3-") {
        if !region.is_empty() {
            return Some(s3_host(region, leading(count - 1)));
        }
    }

    // Region-less form.
    if last == "s3" {
        return Some(s3_host("us-east-1", leading(count - 1)));
    }

    // Forms where the last label is the region and the label before it marks
    // the S3 flavour. The access-point marker must be looked up here: testing
    // the *last* label for it can never match a `<name>.s3-accesspoint.<region>`
    // host, and is shadowed by the `s3-` check above anyway.
    if count >= 2 {
        match labels[count - 2] {
            "s3" | "s3-accesspoint" => return Some(s3_host(last, leading(count - 2))),
            "s3-control" => return Some(s3_host(last, None)),
            _ => {}
        }
    }

    // Generic `<service>.<region>`.
    if count == 2 {
        Some(RoutingHost {
            service: labels[0].to_string(),
            region: last.to_string(),
            bucket: None,
        })
    } else {
        None
    }
}
/// Splits an `x-amz-target` value into service and action.
///
/// The value is split at its last `.`; the left part is looked up first as an
/// exact name and then as a versioned prefix. Returns `None` when there is no
/// `.` or the left part is not a known target. The protocol is always
/// `AwsProtocol::Json`.
fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
    /// Target prefixes that must match exactly.
    const EXACT: &[(&str, &str)] = &[
        ("AWSEvents", "events"),
        ("AmazonSSM", "ssm"),
        ("AmazonSQS", "sqs"),
        ("AmazonSNS", "sns"),
        ("DynamoDB_20120810", "dynamodb"),
        ("DynamoDBStreams_20120810", "dynamodbstreams"),
        ("Logs_20140328", "logs"),
        ("CertificateManager", "acm"),
        ("AnyScaleFrontendService", "application-autoscaling"),
        ("AWSWAF_20190729", "wafv2"),
        ("AmazonAthena", "athena"),
        ("AWSGlue", "glue"),
    ];
    /// Target prefixes matched with `starts_with`, so any suffix is accepted.
    const LEADING: &[(&str, &str)] = &[
        ("secretsmanager", "secretsmanager"),
        ("TrentService", "kms"),
        ("AWSCognitoIdentityProviderService", "cognito-idp"),
        ("AWSCognitoIdentityService", "cognito-identity"),
        ("Kinesis_20131202", "kinesis"),
        ("AmazonEC2ContainerRegistry_V", "ecr"),
        ("AmazonEC2ContainerServiceV", "ecs"),
        ("AWSStepFunctions", "states"),
        ("AWSOrganizationsV", "organizations"),
        ("Firehose_", "firehose"),
    ];

    let (target_prefix, operation) = target.rsplit_once('.')?;

    // No exact name starts with one of the leading patterns and no leading
    // pattern is a prefix of another, so lookup order cannot change the result.
    let service = EXACT
        .iter()
        .find(|&&(name, _)| name == target_prefix)
        .or_else(|| {
            LEADING
                .iter()
                .find(|&&(start, _)| target_prefix.starts_with(start))
        })
        .map(|&(_, service)| service)?;

    Some(DetectedRequest {
        service: service.to_string(),
        action: operation.to_string(),
        protocol: AwsProtocol::Json,
    })
}
/// Returns the REST protocol family for `service`, or `None` when the service
/// is in neither `REST_XML_SERVICES` nor `REST_JSON_SERVICES`.
///
/// The XML table is consulted first. Names are compared exactly, so callers
/// should pass an already-normalised service name.
fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
    let listed_in = |table: &[&str]| table.iter().any(|name| *name == service);

    if listed_in(REST_XML_SERVICES) {
        return Some(AwsProtocol::Rest);
    }
    if listed_in(REST_JSON_SERVICES) {
        return Some(AwsProtocol::RestJson);
    }
    None
}
/// Guesses the service from a Query-protocol action name.
///
/// Used when a request carries an `Action` but no usable signature. Only the
/// STS, IAM and SES actions listed below are known; matching is exact and
/// case-sensitive, and every other action yields `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    const SES_ACTIONS: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];

    // The three lists are disjoint, so the lookup order is irrelevant.
    let owners: [(&str, &[&str]); 3] = [
        ("sts", STS_ACTIONS),
        ("iam", IAM_ACTIONS),
        ("ses", SES_ACTIONS),
    ];
    owners
        .iter()
        .find(|&&(_, actions)| actions.contains(&action))
        .map(|&(service, _)| service.to_string())
}
/// Reads the service name out of the request's `authorization` header.
///
/// The header is handed to `fakecloud_aws::sigv4::parse_sigv4`; the `service`
/// field of its result is passed through `normalize_service_name`. Returns
/// `None` when the header is missing, unreadable as a string, or not parseable.
fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
    headers
        .get("authorization")
        .and_then(|value| value.to_str().ok())
        .and_then(|raw| fakecloud_aws::sigv4::parse_sigv4(raw))
        .map(|signature| normalize_service_name(&signature.service).to_string())
}
/// Maps signing-name aliases onto the service name used for routing.
///
/// `bedrock-runtime` becomes `bedrock` and `apigatewayv2` becomes `apigateway`;
/// every other name, including the empty string, is returned unchanged.
fn normalize_service_name(service: &str) -> &str {
    const ALIASES: &[(&str, &str)] = &[
        ("bedrock-runtime", "bedrock"),
        ("apigatewayv2", "apigateway"),
    ];

    ALIASES
        .iter()
        .find(|&&(alias, _)| alias == service)
        .map_or(service, |&(_, canonical)| canonical)
}
/// Decodes a form-encoded (`key=value&key=value`) request body into a map.
///
/// Public wrapper around `decode_form_urlencoded`: keys and values are
/// URL-decoded, a later duplicate key replaces an earlier one, and a body that
/// is not valid UTF-8 produces an empty map.
pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
decode_form_urlencoded(body)
}
/// Decodes `key=value` pairs separated by `&` into a map.
///
/// * Input that is not valid UTF-8 is treated as empty.
/// * Empty segments (`a=1&&b=2`) are skipped.
/// * A segment without `=` becomes a key with an empty value.
/// * Only the first `=` separates key from value.
/// * Keys and values are passed through `url_decode`.
/// * When a key repeats, the last occurrence wins.
fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
    let text = match std::str::from_utf8(input) {
        Ok(valid) => valid,
        Err(_) => "",
    };

    text.split('&')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let (raw_key, raw_value) = segment.split_once('=').unwrap_or((segment, ""));
            (url_decode(raw_key), url_decode(raw_value))
        })
        .collect()
}
/// Decodes a form-urlencoded component.
///
/// * `+` becomes a space.
/// * `%XX` (two hex digits, either case) becomes the byte `0xXX`.
/// * A `%` that is not followed by two hex digits is dropped together with the
///   (up to two) bytes that were read while looking for the digits.
/// * Every other byte is copied through unchanged.
///
/// The resulting byte sequence is interpreted as UTF-8, so multi-byte
/// characters such as `%C3%A9` decode to a single `é`. Byte sequences that are
/// not valid UTF-8 are replaced with U+FFFD rather than rejected.
fn url_decode(input: &str) -> String {
    // Decoding never grows the data, so the input length is an upper bound.
    let mut decoded: Vec<u8> = Vec::with_capacity(input.len());
    let mut bytes = input.bytes();
    while let Some(b) = bytes.next() {
        match b {
            b'+' => decoded.push(b' '),
            b'%' => {
                // Both positions are consumed even if the first is not a hex digit.
                let high = bytes.next().and_then(from_hex);
                let low = bytes.next().and_then(from_hex);
                if let (Some(h), Some(l)) = (high, low) {
                    decoded.push(h << 4 | l);
                }
            }
            other => decoded.push(other),
        }
    }
    // Decode the collected bytes as a whole: pushing each byte as its own
    // `char` would turn every multi-byte UTF-8 sequence into mojibake.
    String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}

/// Returns the value (0–15) of an ASCII hex digit, or `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
// Unit tests for this file: `x-amz-target` parsing, form-body decoding,
// action-based inference, hostname routing, and the end-to-end
// `detect_service` precedence rules.
#[cfg(test)]
mod tests {
use super::*;
// --- parse_amz_target: exact and versioned target prefixes ---
#[test]
fn parse_amz_target_events() {
let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
assert_eq!(result.service, "events");
assert_eq!(result.action, "PutEvents");
assert_eq!(result.protocol, AwsProtocol::Json);
}
#[test]
fn parse_amz_target_ssm() {
let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
assert_eq!(result.service, "ssm");
assert_eq!(result.action, "GetParameter");
}
#[test]
fn parse_amz_target_kinesis() {
let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
assert_eq!(result.service, "kinesis");
assert_eq!(result.action, "ListStreams");
assert_eq!(result.protocol, AwsProtocol::Json);
}
// --- parse_query_body: form decoding, empty input, duplicate keys ---
#[test]
fn parse_query_body_basic() {
let body = Bytes::from(
"Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
);
let params = parse_query_body(&body);
assert_eq!(params.get("Action").unwrap(), "SendMessage");
assert_eq!(params.get("MessageBody").unwrap(), "hello");
}
#[test]
fn parse_query_body_empty_returns_empty_map() {
let body = Bytes::from("");
let params = parse_query_body(&body);
assert!(params.is_empty());
}
#[test]
fn parse_query_body_duplicate_keys_last_wins() {
let body = Bytes::from("key=a&key=b");
let params = parse_query_body(&body);
assert_eq!(params.get("key").unwrap(), "b");
}
#[test]
fn parse_query_body_single_key() {
let body = Bytes::from("key=value");
let params = parse_query_body(&body);
assert_eq!(params.get("key").unwrap(), "value");
}
// --- parse_amz_target: prefix matches and rejections ---
#[test]
fn parse_amz_target_ecs() {
let result = parse_amz_target("AmazonEC2ContainerServiceV20141113.ListClusters").unwrap();
assert_eq!(result.service, "ecs");
assert_eq!(result.action, "ListClusters");
assert_eq!(result.protocol, AwsProtocol::Json);
}
#[test]
fn parse_amz_target_invalid_returns_none() {
assert!(parse_amz_target("NoDotHere").is_none());
assert!(parse_amz_target("").is_none());
}
#[test]
fn parse_amz_target_various_prefixes() {
assert_eq!(
parse_amz_target("AmazonSQS.SendMessage").unwrap().service,
"sqs"
);
assert_eq!(
parse_amz_target("AmazonSNS.Publish").unwrap().service,
"sns"
);
assert_eq!(
parse_amz_target("DynamoDB_20120810.GetItem")
.unwrap()
.service,
"dynamodb"
);
assert_eq!(
parse_amz_target("Logs_20140328.PutLogEvents")
.unwrap()
.service,
"logs"
);
assert_eq!(
parse_amz_target("secretsmanager.GetSecretValue")
.unwrap()
.service,
"secretsmanager"
);
assert_eq!(
parse_amz_target("TrentService.Encrypt").unwrap().service,
"kms"
);
assert_eq!(
parse_amz_target("AWSCognitoIdentityProviderService.InitiateAuth")
.unwrap()
.service,
"cognito-idp"
);
assert_eq!(
parse_amz_target("AWSStepFunctions.StartExecution")
.unwrap()
.service,
"states"
);
assert_eq!(
parse_amz_target("AWSOrganizationsV20161128.CreateOrganization")
.unwrap()
.service,
"organizations"
);
assert!(parse_amz_target("UnknownServicePrefix.Action").is_none());
}
// --- infer_service_from_action: STS / IAM / SES tables ---
#[test]
fn infer_service_from_action_maps_sts() {
assert_eq!(
infer_service_from_action("AssumeRole").as_deref(),
Some("sts")
);
assert_eq!(
infer_service_from_action("GetCallerIdentity").as_deref(),
Some("sts")
);
}
#[test]
fn infer_service_from_action_maps_iam() {
assert_eq!(
infer_service_from_action("CreateUser").as_deref(),
Some("iam")
);
assert_eq!(
infer_service_from_action("ListRoles").as_deref(),
Some("iam")
);
}
#[test]
fn infer_service_from_action_maps_ses() {
assert_eq!(
infer_service_from_action("SendEmail").as_deref(),
Some("ses")
);
assert_eq!(
infer_service_from_action("ListIdentities").as_deref(),
Some("ses")
);
}
#[test]
fn infer_service_from_action_unknown_returns_none() {
assert!(infer_service_from_action("NotARealAction").is_none());
}
#[test]
fn rest_protocol_for_returns_none_for_non_rest_service() {
assert!(rest_protocol_for("sqs").is_none());
}
// --- url_decode / from_hex ---
#[test]
fn url_decode_handles_percent_and_plus() {
assert_eq!(url_decode("hello+world"), "hello world");
assert_eq!(url_decode("hello%20world"), "hello world");
assert_eq!(url_decode("100%25"), "100%");
}
// A `%` escape with non-hex digits is dropped along with the two bytes after it.
#[test]
fn url_decode_ignores_malformed_percent() {
assert_eq!(url_decode("%ZZ"), "");
}
#[test]
fn from_hex_valid_digits() {
assert_eq!(from_hex(b'0'), Some(0));
assert_eq!(from_hex(b'9'), Some(9));
assert_eq!(from_hex(b'a'), Some(10));
assert_eq!(from_hex(b'F'), Some(15));
}
#[test]
fn from_hex_invalid_returns_none() {
assert!(from_hex(b'g').is_none());
assert!(from_hex(b' ').is_none());
}
// --- detect_service: one test per detection strategy ---
#[test]
fn detect_service_via_amz_target() {
let mut headers = HeaderMap::new();
headers.insert("x-amz-target", "AmazonSSM.GetParameter".parse().unwrap());
let query = HashMap::new();
let body = Bytes::new();
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "ssm");
assert_eq!(detected.action, "GetParameter");
}
#[test]
fn detect_service_via_query_action_with_inferred_service() {
let headers = HeaderMap::new();
let mut query = HashMap::new();
query.insert("Action".to_string(), "AssumeRole".to_string());
let body = Bytes::new();
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "sts");
assert_eq!(detected.action, "AssumeRole");
assert_eq!(detected.protocol, AwsProtocol::Query);
}
#[test]
fn detect_service_via_form_body() {
let headers = HeaderMap::new();
let query = HashMap::new();
let body = Bytes::from("Action=SendEmail&Source=x%40y.com");
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "ses");
assert_eq!(detected.action, "SendEmail");
}
#[test]
fn detect_service_via_sigv2_presigned() {
let headers = HeaderMap::new();
let mut query = HashMap::new();
query.insert("AWSAccessKeyId".to_string(), "AKID".to_string());
query.insert("Signature".to_string(), "sig".to_string());
query.insert("Expires".to_string(), "1234567890".to_string());
let body = Bytes::new();
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "s3");
assert_eq!(detected.protocol, AwsProtocol::Rest);
}
#[test]
fn detect_service_via_sigv4_presigned_credential() {
let headers = HeaderMap::new();
let mut query = HashMap::new();
query.insert(
"X-Amz-Credential".to_string(),
"AKID/20240101/us-east-1/s3/aws4_request".to_string(),
);
let body = Bytes::new();
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "s3");
assert_eq!(detected.protocol, AwsProtocol::Rest);
}
#[test]
fn detect_service_unknown_returns_none() {
let headers = HeaderMap::new();
let query = HashMap::new();
let body = Bytes::new();
assert!(detect_service(&headers, &query, &body).is_none());
}
// --- normalize_service_name: aliases and their effect on detection ---
#[test]
fn normalize_service_name_aliases_apigatewayv2_to_apigateway() {
assert_eq!(normalize_service_name("apigatewayv2"), "apigateway");
}
#[test]
fn normalize_service_name_aliases_bedrock_runtime_to_bedrock() {
assert_eq!(normalize_service_name("bedrock-runtime"), "bedrock");
}
#[test]
fn normalize_service_name_passes_through_unaliased_services() {
assert_eq!(normalize_service_name("bedrock"), "bedrock");
assert_eq!(normalize_service_name("s3"), "s3");
assert_eq!(normalize_service_name("lambda"), "lambda");
assert_eq!(normalize_service_name(""), "");
assert_eq!(
normalize_service_name("unknown-future-service"),
"unknown-future-service"
);
}
#[test]
fn detect_service_via_authorization_header_normalizes_bedrock_runtime() {
let mut headers = HeaderMap::new();
headers.insert(
"authorization",
"AWS4-HMAC-SHA256 \
Credential=AKID/20240101/us-east-1/bedrock-runtime/aws4_request, \
SignedHeaders=host, Signature=abc"
.parse()
.unwrap(),
);
let query = HashMap::new();
let body = Bytes::new();
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "bedrock");
assert_eq!(detected.protocol, AwsProtocol::RestJson);
}
#[test]
fn detect_service_via_sigv4_presigned_credential_normalizes_bedrock_runtime() {
let headers = HeaderMap::new();
let mut query = HashMap::new();
query.insert(
"X-Amz-Credential".to_string(),
"AKID/20240101/us-east-1/bedrock-runtime/aws4_request".to_string(),
);
let body = Bytes::new();
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "bedrock");
assert_eq!(detected.protocol, AwsProtocol::RestJson);
}
// --- parse_routing_host: LocalStack-style hostnames ---
#[test]
fn parse_routing_host_localstack_basic() {
let h = parse_routing_host("sqs.us-east-1.localhost.localstack.cloud").unwrap();
assert_eq!(h.service, "sqs");
assert_eq!(h.region, "us-east-1");
assert!(h.bucket.is_none());
}
#[test]
fn parse_routing_host_localstack_with_port() {
let h = parse_routing_host("lambda.eu-west-1.localhost.localstack.cloud:4566").unwrap();
assert_eq!(h.service, "lambda");
assert_eq!(h.region, "eu-west-1");
assert!(h.bucket.is_none());
}
#[test]
fn parse_routing_host_case_insensitive() {
let h = parse_routing_host("SQS.US-EAST-1.LOCALHOST.LOCALSTACK.CLOUD:4566").unwrap();
assert_eq!(h.service, "sqs");
assert_eq!(h.region, "us-east-1");
let h = parse_routing_host("LAMBDA.US-EAST-1.AMAZONAWS.COM").unwrap();
assert_eq!(h.service, "lambda");
assert_eq!(h.region, "us-east-1");
}
#[test]
fn parse_routing_host_localstack_s3_virtual_hosted() {
let h =
parse_routing_host("my-bucket.s3.us-east-1.localhost.localstack.cloud:4566").unwrap();
assert_eq!(h.service, "s3");
assert_eq!(h.region, "us-east-1");
assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
}
#[test]
fn parse_routing_host_localstack_s3_vhost_bucket_with_dots() {
let h = parse_routing_host("a.b.c.s3.us-east-1.localhost.localstack.cloud").unwrap();
assert_eq!(h.service, "s3");
assert_eq!(h.region, "us-east-1");
assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
}
// --- parse_routing_host: amazonaws.com hostnames ---
#[test]
fn parse_routing_host_aws_service_region() {
let h = parse_routing_host("sqs.us-east-1.amazonaws.com").unwrap();
assert_eq!(h.service, "sqs");
assert_eq!(h.region, "us-east-1");
assert!(h.bucket.is_none());
let h = parse_routing_host("dynamodb.eu-west-2.amazonaws.com:443").unwrap();
assert_eq!(h.service, "dynamodb");
assert_eq!(h.region, "eu-west-2");
}
#[test]
fn parse_routing_host_aws_s3_path_style_modern() {
let h = parse_routing_host("s3.us-east-1.amazonaws.com").unwrap();
assert_eq!(h.service, "s3");
assert_eq!(h.region, "us-east-1");
assert!(h.bucket.is_none());
}
#[test]
fn parse_routing_host_aws_s3_virtual_hosted_modern() {
let h = parse_routing_host("my-bucket.s3.us-east-1.amazonaws.com").unwrap();
assert_eq!(h.service, "s3");
assert_eq!(h.region, "us-east-1");
assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
}
#[test]
fn parse_routing_host_aws_s3_vhost_bucket_with_dots() {
let h = parse_routing_host("a.b.c.s3.us-east-1.amazonaws.com").unwrap();
assert_eq!(h.service, "s3");
assert_eq!(h.region, "us-east-1");
assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
}
#[test]
fn parse_routing_host_aws_s3_legacy_global() {
let h = parse_routing_host("s3.amazonaws.com").unwrap();
assert_eq!(h.service, "s3");
assert_eq!(h.region, "us-east-1");
assert!(h.bucket.is_none());
let h = parse_routing_host("my-bucket.s3.amazonaws.com").unwrap();
assert_eq!(h.service, "s3");
assert_eq!(h.region, "us-east-1");
assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
}
#[test]
fn parse_routing_host_aws_s3_legacy_global_dotted_bucket() {
let h = parse_routing_host("a.b.c.s3.amazonaws.com").unwrap();
assert_eq!(h.service, "s3");
assert_eq!(h.region, "us-east-1");
assert_eq!(h.bucket.as_deref(), Some("a.b.c"));
}
#[test]
fn parse_routing_host_aws_s3_dash_separated() {
let h = parse_routing_host("s3-us-west-2.amazonaws.com").unwrap();
assert_eq!(h.service, "s3");
assert_eq!(h.region, "us-west-2");
assert!(h.bucket.is_none());
let h = parse_routing_host("my-bucket.s3-us-west-2.amazonaws.com").unwrap();
assert_eq!(h.service, "s3");
assert_eq!(h.region, "us-west-2");
assert_eq!(h.bucket.as_deref(), Some("my-bucket"));
}
// --- parse_routing_host: hostnames that must be rejected ---
#[test]
fn parse_routing_host_rejects_plain_localhost() {
assert!(parse_routing_host("localhost:4566").is_none());
assert!(parse_routing_host("127.0.0.1:4566").is_none());
}
#[test]
fn parse_routing_host_rejects_unknown_suffix() {
assert!(parse_routing_host("sqs.us-east-1.example.com").is_none());
assert!(parse_routing_host("s3.us-east-1.aws").is_none());
}
#[test]
fn parse_routing_host_empty_and_malformed_rejected() {
assert!(parse_routing_host("").is_none());
assert!(parse_routing_host(".localhost.localstack.cloud").is_none());
assert!(parse_routing_host("..localhost.localstack.cloud").is_none());
assert!(parse_routing_host("sqs.localhost.localstack.cloud").is_none());
assert!(parse_routing_host("foo.bar.baz.localhost.localstack.cloud").is_none());
assert!(parse_routing_host(".amazonaws.com").is_none());
assert!(parse_routing_host("amazonaws.com").is_none());
}
// --- detect_service: host-based detection and precedence over the host ---
#[test]
fn detect_service_via_host_for_rest_service() {
let mut headers = HeaderMap::new();
headers.insert(
"host",
"s3.us-east-1.localhost.localstack.cloud:4566"
.parse()
.unwrap(),
);
let query = HashMap::new();
let body = Bytes::new();
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "s3");
assert_eq!(detected.protocol, AwsProtocol::Rest);
}
#[test]
fn detect_service_via_host_for_rest_json_service() {
let mut headers = HeaderMap::new();
headers.insert(
"host",
"lambda.us-east-1.localhost.localstack.cloud:4566"
.parse()
.unwrap(),
);
let query = HashMap::new();
let body = Bytes::new();
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "lambda");
assert_eq!(detected.protocol, AwsProtocol::RestJson);
}
// `ListQueues` is not in the action-inference table, so the service can only
// come from the host here.
#[test]
fn detect_service_via_host_plus_query_action() {
let mut headers = HeaderMap::new();
headers.insert(
"host",
"sqs.us-east-1.localhost.localstack.cloud:4566"
.parse()
.unwrap(),
);
let mut query = HashMap::new();
query.insert("Action".to_string(), "ListQueues".to_string());
let body = Bytes::new();
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "sqs");
assert_eq!(detected.action, "ListQueues");
assert_eq!(detected.protocol, AwsProtocol::Query);
}
// The signed service (`s3`) must beat the conflicting host (`lambda`).
#[test]
fn detect_service_sigv4_wins_over_host() {
let mut headers = HeaderMap::new();
headers.insert(
"authorization",
"AWS4-HMAC-SHA256 Credential=AKID/20240101/us-east-1/s3/aws4_request, \
SignedHeaders=host, Signature=abc"
.parse()
.unwrap(),
);
headers.insert(
"host",
"lambda.us-east-1.localhost.localstack.cloud:4566"
.parse()
.unwrap(),
);
let query = HashMap::new();
let body = Bytes::new();
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "s3");
assert_eq!(detected.protocol, AwsProtocol::Rest);
}
#[test]
fn detect_service_host_for_virtual_hosted_s3() {
let mut headers = HeaderMap::new();
headers.insert(
"host",
"my-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
.parse()
.unwrap(),
);
let query = HashMap::new();
let body = Bytes::new();
let detected = detect_service(&headers, &query, &body).unwrap();
assert_eq!(detected.service, "s3");
assert_eq!(detected.protocol, AwsProtocol::Rest);
}
}