Skip to main content

fakecloud_core/
protocol.rs

1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// The wire protocol used by an AWS service.
///
/// The protocol determines how the request is parsed and how the response is encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AwsProtocol {
    /// Query protocol: form-encoded body (or query string), `Action` parameter, XML response.
    /// Used by: SQS, SNS, IAM, STS, SES v1.
    Query,
    /// JSON protocol: JSON body, `X-Amz-Target` header, JSON response.
    /// Used by: SQS (`AmazonSQS` target), SSM, EventBridge, DynamoDB, Secrets Manager, KMS,
    /// CloudWatch Logs, Cognito Identity Provider.
    Json,
    /// REST protocol: HTTP method + path-based routing, XML responses.
    /// Used by: S3 (the only service routed here; Route 53 also uses this protocol in AWS).
    Rest,
    /// REST-JSON protocol: HTTP method + path-based routing, JSON responses.
    /// Used by: Lambda, SES v2.
    RestJson,
}
21
/// SigV4 signing names routed with the XML-bodied `AwsProtocol::Rest` protocol once the
/// header- and `Action`-based detection rules have failed to match.
const REST_XML_SERVICES: &[&str] = &["s3"];

/// SigV4 signing names routed with the `AwsProtocol::RestJson` protocol.
///
/// `ses` here stands for SES v2: SES v1 (Query protocol) requests carry an `Action`
/// parameter, so `detect_service` claims them as Query before this list is consulted.
const REST_JSON_SERVICES: &[&str] = &["lambda", "ses"];
27
28/// Detected service name and action from an incoming HTTP request.
29#[derive(Debug)]
30pub struct DetectedRequest {
31    pub service: String,
32    pub action: String,
33    pub protocol: AwsProtocol,
34}
35
36/// Detect the target service and action from HTTP request components.
37pub fn detect_service(
38    headers: &HeaderMap,
39    query_params: &HashMap<String, String>,
40    body: &Bytes,
41) -> Option<DetectedRequest> {
42    // 1. Check X-Amz-Target header (JSON protocol)
43    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
44        return parse_amz_target(target);
45    }
46
47    // 2. Check for Query protocol (Action parameter in query string or form body)
48    if let Some(action) = query_params.get("Action") {
49        let service =
50            extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
51        if let Some(service) = service {
52            return Some(DetectedRequest {
53                service,
54                action: action.clone(),
55                protocol: AwsProtocol::Query,
56            });
57        }
58    }
59
60    // 3. Try form-encoded body
61    {
62        let form_params = decode_form_urlencoded(body);
63
64        if let Some(action) = form_params.get("Action") {
65            let service =
66                extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
67            if let Some(service) = service {
68                return Some(DetectedRequest {
69                    service,
70                    action: action.clone(),
71                    protocol: AwsProtocol::Query,
72                });
73            }
74        }
75    }
76
77    // 4. Fallback: check auth header for REST-style services (S3, Lambda, SES, etc.)
78    if let Some(service) = extract_service_from_auth(headers) {
79        if let Some(protocol) = rest_protocol_for(&service) {
80            return Some(DetectedRequest {
81                service,
82                action: String::new(), // REST services determine action from method+path
83                protocol,
84            });
85        }
86    }
87
88    // 5. Check query params for presigned URL auth (X-Amz-Credential for SigV4)
89    if let Some(credential) = query_params.get("X-Amz-Credential") {
90        // Format: AKID/date/region/service/aws4_request
91        let parts: Vec<&str> = credential.split('/').collect();
92        if parts.len() >= 4 {
93            let service = parts[3].to_string();
94            if let Some(protocol) = rest_protocol_for(&service) {
95                return Some(DetectedRequest {
96                    service,
97                    action: String::new(),
98                    protocol,
99                });
100            }
101        }
102    }
103
104    // 6. Check for SigV2-style presigned URL (AWSAccessKeyId + Signature + Expires)
105    //    Only match when all three SigV2 presigned-URL parameters are present so
106    //    we don't accidentally claim non-S3 requests.
107    if query_params.contains_key("AWSAccessKeyId")
108        && query_params.contains_key("Signature")
109        && query_params.contains_key("Expires")
110    {
111        return Some(DetectedRequest {
112            service: "s3".to_string(),
113            action: String::new(),
114            protocol: AwsProtocol::Rest,
115        });
116    }
117
118    None
119}
120
121/// Parse `X-Amz-Target: AWSEvents.PutEvents` -> service=events, action=PutEvents
122/// Parse `X-Amz-Target: AmazonSSM.GetParameter` -> service=ssm, action=GetParameter
123fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
124    let (prefix, action) = target.rsplit_once('.')?;
125
126    let service = match prefix {
127        "AWSEvents" => "events",
128        "AmazonSSM" => "ssm",
129        "AmazonSQS" => "sqs",
130        "AmazonSNS" => "sns",
131        "DynamoDB_20120810" => "dynamodb",
132        "Logs_20140328" => "logs",
133        s if s.starts_with("secretsmanager") => "secretsmanager",
134        s if s.starts_with("TrentService") => "kms",
135        s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
136        _ => return None,
137    };
138
139    Some(DetectedRequest {
140        service: service.to_string(),
141        action: action.to_string(),
142        protocol: AwsProtocol::Json,
143    })
144}
145
146/// Returns the REST protocol variant for a service, or None if not a REST service.
147fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
148    if REST_XML_SERVICES.contains(&service) {
149        Some(AwsProtocol::Rest)
150    } else if REST_JSON_SERVICES.contains(&service) {
151        Some(AwsProtocol::RestJson)
152    } else {
153        None
154    }
155}
156
/// Infer the owning service from a Query-protocol action name.
///
/// This is the fallback when the request carries no SigV4 credential scope naming the
/// service. Some AWS operations (e.g. AssumeRoleWithSAML, AssumeRoleWithWebIdentity) do
/// not require authentication and arrive without an Authorization header. Only the STS,
/// IAM and SES v1 actions listed below are recognised; anything else yields `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    // SES v1 (Query protocol)
    const SES_V1_ACTIONS: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];

    // The three action sets are disjoint, so lookup order does not affect the result.
    let services: [(&str, &[&str]); 3] = [
        ("sts", STS_ACTIONS),
        ("iam", IAM_ACTIONS),
        ("ses", SES_V1_ACTIONS),
    ];
    services
        .iter()
        .find(|(_, actions)| actions.contains(&action))
        .map(|&(service, _)| service.to_owned())
}
224
225/// Extract service name from the SigV4 Authorization header credential scope.
226fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
227    let auth = headers.get("authorization")?.to_str().ok()?;
228    let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
229    Some(info.service)
230}
231
232/// Parse form-encoded body into key-value pairs.
233pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
234    decode_form_urlencoded(body)
235}
236
/// Split an `application/x-www-form-urlencoded` payload into decoded key/value pairs.
///
/// Pairs are separated by `&`. Each pair is split at its first `=`, and a pair without
/// `=` maps to an empty value. Both halves are percent-decoded with `url_decode`.
/// Empty segments are skipped, and a repeated key keeps its last value. Input that is
/// not valid UTF-8 produces an empty map, presumably so binary request bodies are never
/// mistaken for form data.
fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
    let text = std::str::from_utf8(input).unwrap_or_default();
    text.split('&')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let (key, value) = segment.split_once('=').unwrap_or((segment, ""));
            (url_decode(key), url_decode(value))
        })
        .collect()
}

/// Percent-decode a single form-urlencoded component.
///
/// - `+` becomes a space.
/// - `%XX` (two hex digits, either case) becomes the byte `0xXX`.
/// - The decoded bytes are then read as UTF-8. Multi-byte sequences such as `%C3%A9`
///   decode to `é`, raw non-ASCII characters pass through unchanged, and invalid UTF-8
///   is replaced with U+FFFD.
/// - A `%` that is not followed by two hex digits is kept literally, as in the WHATWG
///   form-urlencoded parser.
fn url_decode(input: &str) -> String {
    let raw = input.as_bytes();
    let mut decoded = Vec::with_capacity(raw.len());
    let mut i = 0;
    while i < raw.len() {
        match raw[i] {
            b'+' => decoded.push(b' '),
            b'%' => {
                let hex_pair = raw
                    .get(i + 1)
                    .copied()
                    .and_then(from_hex)
                    .zip(raw.get(i + 2).copied().and_then(from_hex));
                match hex_pair {
                    Some((high, low)) => {
                        decoded.push((high << 4) | low);
                        // Skip the two hex digits just consumed.
                        i += 2;
                    }
                    // Malformed escape: keep the `%`; the following bytes are copied as-is.
                    None => decoded.push(b'%'),
                }
            }
            byte => decoded.push(byte),
        }
        i += 1;
    }
    // Bytes outside escapes come from a valid `&str`, so only decoded escapes can make
    // the result invalid. Take the no-copy path when it is valid.
    match String::from_utf8(decoded) {
        Ok(text) => text,
        Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
    }
}

/// Value of one ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a query-parameter map from string pairs.
    fn params(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_prefix_families() {
        let kms = parse_amz_target("TrentService.Encrypt").unwrap();
        assert_eq!(kms.service, "kms");
        assert_eq!(kms.action, "Encrypt");

        let idp = parse_amz_target("AWSCognitoIdentityProviderService.ListUsers").unwrap();
        assert_eq!(idp.service, "cognito-idp");
        assert_eq!(idp.action, "ListUsers");

        let ddb = parse_amz_target("DynamoDB_20120810.GetItem").unwrap();
        assert_eq!(ddb.service, "dynamodb");
    }

    #[test]
    fn parse_amz_target_rejects_unknown_or_malformed() {
        assert!(parse_amz_target("UnknownService.DoThing").is_none());
        assert!(parse_amz_target("AWSEvents").is_none());
    }

    #[test]
    fn rest_protocol_mapping() {
        assert_eq!(rest_protocol_for("s3"), Some(AwsProtocol::Rest));
        assert_eq!(rest_protocol_for("lambda"), Some(AwsProtocol::RestJson));
        assert_eq!(rest_protocol_for("ses"), Some(AwsProtocol::RestJson));
        assert_eq!(rest_protocol_for("sqs"), None);
    }

    #[test]
    fn detect_amz_target_header() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-amz-target",
            http::HeaderValue::from_static("DynamoDB_20120810.GetItem"),
        );
        let detected = detect_service(&headers, &HashMap::new(), &Bytes::new()).unwrap();
        assert_eq!(detected.service, "dynamodb");
        assert_eq!(detected.action, "GetItem");
        assert_eq!(detected.protocol, AwsProtocol::Json);
    }

    #[test]
    fn detect_unknown_amz_target_does_not_fall_back() {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", http::HeaderValue::from_static("Unknown.Op"));
        let query = params(&[("Action", "GetCallerIdentity")]);
        assert!(detect_service(&headers, &query, &Bytes::new()).is_none());
    }

    #[test]
    fn detect_query_action_without_auth_infers_service() {
        let query = params(&[("Action", "GetCallerIdentity")]);
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "GetCallerIdentity");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_query_action_in_form_body() {
        let body = Bytes::from("Action=AssumeRole&RoleArn=arn");
        let detected = detect_service(&HeaderMap::new(), &HashMap::new(), &body).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "AssumeRole");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_sigv4_presigned_url() {
        let query = params(&[(
            "X-Amz-Credential",
            "AKID/20240101/us-east-1/s3/aws4_request",
        )]);
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
        assert!(detected.action.is_empty());
    }

    #[test]
    fn detect_sigv2_presigned_url_requires_all_params() {
        let full = params(&[
            ("AWSAccessKeyId", "AKID"),
            ("Signature", "sig"),
            ("Expires", "0"),
        ]);
        let detected = detect_service(&HeaderMap::new(), &full, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);

        let partial = params(&[("AWSAccessKeyId", "AKID"), ("Signature", "sig")]);
        assert!(detect_service(&HeaderMap::new(), &partial, &Bytes::new()).is_none());
    }

    #[test]
    fn detect_nothing_for_bare_request() {
        assert!(detect_service(&HeaderMap::new(), &HashMap::new(), &Bytes::new()).is_none());
    }

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("QueueUrl").unwrap(), "http://localhost:4566/queue");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }
}