Skip to main content

fakecloud_core/
protocol.rs

1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// The wire protocol used by an AWS service.
///
/// Each "Used by" list names the services this module routes to that protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AwsProtocol {
    /// Query protocol: `Action` parameter in the query string or a form-encoded body,
    /// XML response.
    /// Used by: SQS, SNS, IAM, STS, SES v1.
    Query,
    /// JSON protocol: JSON body, `X-Amz-Target` header, JSON response.
    /// Used by: SSM, EventBridge, DynamoDB, Secrets Manager, KMS, CloudWatch Logs,
    /// Cognito Identity Provider, Kinesis, Step Functions, and SQS/SNS requests that
    /// carry an `X-Amz-Target` header.
    Json,
    /// REST protocol: HTTP method + path-based routing, XML responses.
    /// Used by: S3.
    Rest,
    /// REST-JSON protocol: HTTP method + path-based routing, JSON responses.
    /// Used by: Lambda, SES v2, API Gateway, Bedrock.
    RestJson,
}
21
/// SigV4 signing names (credential-scope services) whose requests are routed to the
/// REST protocol with XML responses.
const REST_XML_SERVICES: &[&str] = &["s3"];

/// SigV4 signing names (credential-scope services) whose requests are routed to the
/// REST protocol with JSON responses.
const REST_JSON_SERVICES: &[&str] = &[
    "lambda",
    "ses", // SES v2; SES v1 Query requests are matched earlier by their `Action` parameter
    "apigateway",
    "bedrock",
];
27
28/// Detected service name and action from an incoming HTTP request.
29#[derive(Debug)]
30pub struct DetectedRequest {
31    pub service: String,
32    pub action: String,
33    pub protocol: AwsProtocol,
34}
35
36/// Detect the target service and action from HTTP request components.
37pub fn detect_service(
38    headers: &HeaderMap,
39    query_params: &HashMap<String, String>,
40    body: &Bytes,
41) -> Option<DetectedRequest> {
42    // 1. Check X-Amz-Target header (JSON protocol)
43    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
44        return parse_amz_target(target);
45    }
46
47    // 2. Check for Query protocol (Action parameter in query string or form body)
48    if let Some(action) = query_params.get("Action") {
49        let service =
50            extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
51        if let Some(service) = service {
52            return Some(DetectedRequest {
53                service,
54                action: action.clone(),
55                protocol: AwsProtocol::Query,
56            });
57        }
58    }
59
60    // 3. Try form-encoded body
61    {
62        let form_params = decode_form_urlencoded(body);
63
64        if let Some(action) = form_params.get("Action") {
65            let service =
66                extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
67            if let Some(service) = service {
68                return Some(DetectedRequest {
69                    service,
70                    action: action.clone(),
71                    protocol: AwsProtocol::Query,
72                });
73            }
74        }
75    }
76
77    // 4. Fallback: check auth header for REST-style services (S3, Lambda, SES, etc.)
78    if let Some(service) = extract_service_from_auth(headers) {
79        if let Some(protocol) = rest_protocol_for(&service) {
80            return Some(DetectedRequest {
81                service,
82                action: String::new(), // REST services determine action from method+path
83                protocol,
84            });
85        }
86    }
87
88    // 5. Check query params for presigned URL auth (X-Amz-Credential for SigV4)
89    if let Some(credential) = query_params.get("X-Amz-Credential") {
90        // Format: AKID/date/region/service/aws4_request
91        let parts: Vec<&str> = credential.split('/').collect();
92        if parts.len() >= 4 {
93            let service = parts[3].to_string();
94            if let Some(protocol) = rest_protocol_for(&service) {
95                return Some(DetectedRequest {
96                    service,
97                    action: String::new(),
98                    protocol,
99                });
100            }
101        }
102    }
103
104    // 6. Check for SigV2-style presigned URL (AWSAccessKeyId + Signature + Expires)
105    //    Only match when all three SigV2 presigned-URL parameters are present so
106    //    we don't accidentally claim non-S3 requests.
107    if query_params.contains_key("AWSAccessKeyId")
108        && query_params.contains_key("Signature")
109        && query_params.contains_key("Expires")
110    {
111        return Some(DetectedRequest {
112            service: "s3".to_string(),
113            action: String::new(),
114            protocol: AwsProtocol::Rest,
115        });
116    }
117
118    None
119}
120
121/// Parse `X-Amz-Target: AWSEvents.PutEvents` -> service=events, action=PutEvents
122/// Parse `X-Amz-Target: AmazonSSM.GetParameter` -> service=ssm, action=GetParameter
123fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
124    let (prefix, action) = target.rsplit_once('.')?;
125
126    let service = match prefix {
127        "AWSEvents" => "events",
128        "AmazonSSM" => "ssm",
129        "AmazonSQS" => "sqs",
130        "AmazonSNS" => "sns",
131        "DynamoDB_20120810" => "dynamodb",
132        "Logs_20140328" => "logs",
133        s if s.starts_with("secretsmanager") => "secretsmanager",
134        s if s.starts_with("TrentService") => "kms",
135        s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
136        s if s.starts_with("Kinesis_20131202") => "kinesis",
137        s if s.starts_with("AWSStepFunctions") => "states",
138        _ => return None,
139    };
140
141    Some(DetectedRequest {
142        service: service.to_string(),
143        action: action.to_string(),
144        protocol: AwsProtocol::Json,
145    })
146}
147
148/// Returns the REST protocol variant for a service, or None if not a REST service.
149fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
150    if REST_XML_SERVICES.contains(&service) {
151        Some(AwsProtocol::Rest)
152    } else if REST_JSON_SERVICES.contains(&service) {
153        Some(AwsProtocol::RestJson)
154    } else {
155        None
156    }
157}
158
/// Guess the owning service from a Query-protocol `Action` name.
///
/// This is the fallback when a request has no SigV4 `Authorization` header to read the
/// service from. Some operations, e.g. `AssumeRoleWithSAML` and
/// `AssumeRoleWithWebIdentity`, can legitimately be called unsigned.
///
/// Only the fixed STS, IAM and SES v1 action lists below are recognised. Matching is
/// exact and case-sensitive; any other action yields `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    // SES v1 is the Query-protocol flavour of SES.
    const SES_V1_ACTIONS: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];

    let is_one_of = |actions: &[&str]| actions.iter().any(|&known| known == action);

    let service = if is_one_of(STS_ACTIONS) {
        "sts"
    } else if is_one_of(IAM_ACTIONS) {
        "iam"
    } else if is_one_of(SES_V1_ACTIONS) {
        "ses"
    } else {
        return None;
    };
    Some(service.to_owned())
}
226
227/// Extract service name from the SigV4 Authorization header credential scope.
228fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
229    let auth = headers.get("authorization")?.to_str().ok()?;
230    let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
231    Some(info.service)
232}
233
234/// Parse form-encoded body into key-value pairs.
235pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
236    decode_form_urlencoded(body)
237}
238
239fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
240    let s = std::str::from_utf8(input).unwrap_or("");
241    let mut result = HashMap::new();
242    for pair in s.split('&') {
243        if pair.is_empty() {
244            continue;
245        }
246        let (key, value) = match pair.find('=') {
247            Some(pos) => (&pair[..pos], &pair[pos + 1..]),
248            None => (pair, ""),
249        };
250        result.insert(url_decode(key), url_decode(value));
251    }
252    result
253}
254
255fn url_decode(input: &str) -> String {
256    let mut result = String::with_capacity(input.len());
257    let mut bytes = input.bytes();
258    while let Some(b) = bytes.next() {
259        match b {
260            b'+' => result.push(' '),
261            b'%' => {
262                let high = bytes.next().and_then(from_hex);
263                let low = bytes.next().and_then(from_hex);
264                if let (Some(h), Some(l)) = (high, low) {
265                    result.push((h << 4 | l) as char);
266                }
267            }
268            _ => result.push(b as char),
269        }
270    }
271    result
272}
273
/// Numeric value of one ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly those three ranges (bytes >= 0x80 map to non-ASCII
    // chars and are rejected), and the resulting digit is < 16, so the cast cannot truncate.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
282
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_kinesis() {
        let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
        assert_eq!(result.service, "kinesis");
        assert_eq!(result.action, "ListStreams");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_prefix_matched_services() {
        let result = parse_amz_target("secretsmanager.GetSecretValue").unwrap();
        assert_eq!(result.service, "secretsmanager");
        assert_eq!(result.action, "GetSecretValue");
        assert_eq!(parse_amz_target("TrentService.Encrypt").unwrap().service, "kms");
        assert_eq!(
            parse_amz_target("AWSStepFunctions.StartExecution").unwrap().service,
            "states"
        );
    }

    #[test]
    fn parse_amz_target_rejects_unknown_or_malformed() {
        assert!(parse_amz_target("UnknownService.DoThing").is_none());
        assert!(parse_amz_target("AmazonSSM").is_none());
        // Exact-match prefixes must not match as mere prefixes.
        assert!(parse_amz_target("AmazonSSMExtra.GetParameter").is_none());
    }

    #[test]
    fn rest_protocol_lookup() {
        assert_eq!(rest_protocol_for("s3"), Some(AwsProtocol::Rest));
        assert_eq!(rest_protocol_for("lambda"), Some(AwsProtocol::RestJson));
        assert_eq!(rest_protocol_for("sqs"), None);
    }

    #[test]
    fn infer_service_for_unsigned_actions() {
        assert_eq!(
            infer_service_from_action("AssumeRoleWithWebIdentity").as_deref(),
            Some("sts")
        );
        assert_eq!(infer_service_from_action("CreateRole").as_deref(), Some("iam"));
        assert_eq!(infer_service_from_action("SendEmail").as_deref(), Some("ses"));
        assert_eq!(infer_service_from_action("SendMessage"), None);
    }

    #[test]
    fn detect_unsigned_query_string_action() {
        let mut query = HashMap::new();
        query.insert(
            "Action".to_string(),
            "AssumeRoleWithWebIdentity".to_string(),
        );
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "AssumeRoleWithWebIdentity");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_unsigned_form_body_action() {
        let body = Bytes::from("Action=GetCallerIdentity&Version=2011-06-15");
        let detected = detect_service(&HeaderMap::new(), &HashMap::new(), &body).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "GetCallerIdentity");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_sigv4_presigned_url() {
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKID/20240101/us-east-1/s3/aws4_request".to_string(),
        );
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.action, "");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_sigv2_presigned_url_requires_all_params() {
        let mut query = HashMap::new();
        query.insert("AWSAccessKeyId".to_string(), "AKID".to_string());
        query.insert("Signature".to_string(), "sig".to_string());
        assert!(detect_service(&HeaderMap::new(), &query, &Bytes::new()).is_none());

        query.insert("Expires".to_string(), "1700000000".to_string());
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_nothing_from_empty_request() {
        assert!(detect_service(&HeaderMap::new(), &HashMap::new(), &Bytes::new()).is_none());
    }

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(
            params.get("QueueUrl").unwrap(),
            "http://localhost:4566/queue"
        );
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }
}