Skip to main content

fakecloud_core/
protocol.rs

1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// The wire protocol used by an AWS service.
///
/// The "Used by" lists name the services this module routes to each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AwsProtocol {
    /// Query protocol: form-encoded `Action` parameter (query string or body),
    /// XML response.
    /// Used by: SQS, SNS, IAM, STS, SES v1.
    Query,
    /// JSON protocol: JSON body, operation named by the `X-Amz-Target` header,
    /// JSON response.
    /// Used by: SSM, EventBridge, DynamoDB, Secrets Manager, KMS, CloudWatch
    /// Logs, Cognito User Pools, Kinesis, and SQS's JSON variant.
    Json,
    /// REST protocol: HTTP method + path-based routing, XML responses.
    /// Used by: S3 (Route 53 also uses this protocol in AWS).
    Rest,
    /// REST-JSON protocol: HTTP method + path-based routing, JSON responses.
    /// Used by: Lambda, SES v2 (API Gateway also uses this protocol in AWS).
    RestJson,
}
21
/// SigV4 service names that are routed as REST with XML responses.
///
/// Compared against the credential scope of the `Authorization` header, or of
/// a presigned URL's `X-Amz-Credential` query parameter.
const REST_XML_SERVICES: &[&str] = &["s3"];

/// SigV4 service names that are routed as REST with JSON responses.
///
/// Compared against the same credential scopes as `REST_XML_SERVICES`. SES v1
/// Query requests also sign as `ses`, but they carry an `Action` parameter and
/// are claimed by the Query checks in `detect_service` before this list is
/// consulted.
const REST_JSON_SERVICES: &[&str] = &["lambda", "ses"];
27
28/// Detected service name and action from an incoming HTTP request.
29#[derive(Debug)]
30pub struct DetectedRequest {
31    pub service: String,
32    pub action: String,
33    pub protocol: AwsProtocol,
34}
35
36/// Detect the target service and action from HTTP request components.
37pub fn detect_service(
38    headers: &HeaderMap,
39    query_params: &HashMap<String, String>,
40    body: &Bytes,
41) -> Option<DetectedRequest> {
42    // 1. Check X-Amz-Target header (JSON protocol)
43    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
44        return parse_amz_target(target);
45    }
46
47    // 2. Check for Query protocol (Action parameter in query string or form body)
48    if let Some(action) = query_params.get("Action") {
49        let service =
50            extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
51        if let Some(service) = service {
52            return Some(DetectedRequest {
53                service,
54                action: action.clone(),
55                protocol: AwsProtocol::Query,
56            });
57        }
58    }
59
60    // 3. Try form-encoded body
61    {
62        let form_params = decode_form_urlencoded(body);
63
64        if let Some(action) = form_params.get("Action") {
65            let service =
66                extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
67            if let Some(service) = service {
68                return Some(DetectedRequest {
69                    service,
70                    action: action.clone(),
71                    protocol: AwsProtocol::Query,
72                });
73            }
74        }
75    }
76
77    // 4. Fallback: check auth header for REST-style services (S3, Lambda, SES, etc.)
78    if let Some(service) = extract_service_from_auth(headers) {
79        if let Some(protocol) = rest_protocol_for(&service) {
80            return Some(DetectedRequest {
81                service,
82                action: String::new(), // REST services determine action from method+path
83                protocol,
84            });
85        }
86    }
87
88    // 5. Check query params for presigned URL auth (X-Amz-Credential for SigV4)
89    if let Some(credential) = query_params.get("X-Amz-Credential") {
90        // Format: AKID/date/region/service/aws4_request
91        let parts: Vec<&str> = credential.split('/').collect();
92        if parts.len() >= 4 {
93            let service = parts[3].to_string();
94            if let Some(protocol) = rest_protocol_for(&service) {
95                return Some(DetectedRequest {
96                    service,
97                    action: String::new(),
98                    protocol,
99                });
100            }
101        }
102    }
103
104    // 6. Check for SigV2-style presigned URL (AWSAccessKeyId + Signature + Expires)
105    //    Only match when all three SigV2 presigned-URL parameters are present so
106    //    we don't accidentally claim non-S3 requests.
107    if query_params.contains_key("AWSAccessKeyId")
108        && query_params.contains_key("Signature")
109        && query_params.contains_key("Expires")
110    {
111        return Some(DetectedRequest {
112            service: "s3".to_string(),
113            action: String::new(),
114            protocol: AwsProtocol::Rest,
115        });
116    }
117
118    None
119}
120
121/// Parse `X-Amz-Target: AWSEvents.PutEvents` -> service=events, action=PutEvents
122/// Parse `X-Amz-Target: AmazonSSM.GetParameter` -> service=ssm, action=GetParameter
123fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
124    let (prefix, action) = target.rsplit_once('.')?;
125
126    let service = match prefix {
127        "AWSEvents" => "events",
128        "AmazonSSM" => "ssm",
129        "AmazonSQS" => "sqs",
130        "AmazonSNS" => "sns",
131        "DynamoDB_20120810" => "dynamodb",
132        "Logs_20140328" => "logs",
133        s if s.starts_with("secretsmanager") => "secretsmanager",
134        s if s.starts_with("TrentService") => "kms",
135        s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
136        s if s.starts_with("Kinesis_20131202") => "kinesis",
137        _ => return None,
138    };
139
140    Some(DetectedRequest {
141        service: service.to_string(),
142        action: action.to_string(),
143        protocol: AwsProtocol::Json,
144    })
145}
146
147/// Returns the REST protocol variant for a service, or None if not a REST service.
148fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
149    if REST_XML_SERVICES.contains(&service) {
150        Some(AwsProtocol::Rest)
151    } else if REST_JSON_SERVICES.contains(&service) {
152        Some(AwsProtocol::RestJson)
153    } else {
154        None
155    }
156}
157
/// Guess the owning service from a Query-protocol action name.
///
/// This is the fallback when the request has no SigV4 `Authorization` header
/// to read the service from. For example, STS's `AssumeRoleWithSAML` and
/// `AssumeRoleWithWebIdentity` can be called unauthenticated. Only a fixed set
/// of STS, IAM and SES v1 action names is recognized; any other name yields
/// `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    // SES v1 (Query protocol)
    const SES_ACTIONS: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];

    let catalogue: [(&str, &[&str]); 3] = [
        ("sts", STS_ACTIONS),
        ("iam", IAM_ACTIONS),
        ("ses", SES_ACTIONS),
    ];
    catalogue
        .iter()
        .find(|(_, actions)| actions.contains(&action))
        .map(|&(service, _)| service.to_owned())
}
225
226/// Extract service name from the SigV4 Authorization header credential scope.
227fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
228    let auth = headers.get("authorization")?.to_str().ok()?;
229    let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
230    Some(info.service)
231}
232
233/// Parse form-encoded body into key-value pairs.
234pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
235    decode_form_urlencoded(body)
236}
237
/// Decode an `application/x-www-form-urlencoded` payload into a key/value map.
///
/// Pairs are separated by `&` and split at the first `=`. A pair without `=`
/// maps to an empty value, and empty segments (e.g. from `a=1&&b=2`) are
/// skipped. Keys and values are percent-decoded with [`url_decode`]. When a
/// key repeats, the last occurrence wins.
///
/// A payload that is not valid UTF-8 yields an empty map, so binary bodies
/// never produce fields.
fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
    let text = std::str::from_utf8(input).unwrap_or("");
    text.split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (url_decode(key), url_decode(value))
        })
        .collect()
}

/// Percent-decode one form-urlencoded component into a `String`.
///
/// - `+` decodes to a space.
/// - `%XX` (hex digits of either case) decodes to the byte `0xXX`.
///
/// Decoded bytes are collected first and only then interpreted as UTF-8. As a
/// result, multi-byte escapes such as `%C3%A9` become `é` and raw non-ASCII
/// text passes through unchanged. Byte sequences that are not valid UTF-8 are
/// replaced with U+FFFD. A `%` that is not followed by two hex digits is kept
/// literally, as in the WHATWG URL Standard's percent-decode algorithm.
fn url_decode(input: &str) -> String {
    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => {
                let high = bytes.get(i + 1).copied().and_then(from_hex);
                let low = bytes.get(i + 2).copied().and_then(from_hex);
                match (high, low) {
                    (Some(h), Some(l)) => {
                        decoded.push((h << 4) | l);
                        i += 3;
                    }
                    // Malformed escape: keep the `%` and re-scan what follows.
                    _ => {
                        decoded.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    // Reuse the buffer when it is valid UTF-8; fall back to lossy conversion.
    String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}

/// Value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned query-parameter map from literal pairs.
    fn params(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_kinesis() {
        let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
        assert_eq!(result.service, "kinesis");
        assert_eq!(result.action, "ListStreams");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_prefix_matched_services() {
        let secrets = parse_amz_target("secretsmanager.GetSecretValue").unwrap();
        assert_eq!(secrets.service, "secretsmanager");
        assert_eq!(secrets.action, "GetSecretValue");

        let kms = parse_amz_target("TrentService.Encrypt").unwrap();
        assert_eq!(kms.service, "kms");
        assert_eq!(kms.action, "Encrypt");

        let cognito = parse_amz_target("AWSCognitoIdentityProviderService.InitiateAuth").unwrap();
        assert_eq!(cognito.service, "cognito-idp");
        assert_eq!(cognito.action, "InitiateAuth");
    }

    #[test]
    fn parse_amz_target_rejects_unknown_or_malformed() {
        assert!(parse_amz_target("UnknownService.DoThing").is_none());
        assert!(parse_amz_target("AWSEventsPutEvents").is_none());
    }

    #[test]
    fn rest_protocol_lookup() {
        assert_eq!(rest_protocol_for("s3"), Some(AwsProtocol::Rest));
        assert_eq!(rest_protocol_for("lambda"), Some(AwsProtocol::RestJson));
        assert_eq!(rest_protocol_for("ses"), Some(AwsProtocol::RestJson));
        assert_eq!(rest_protocol_for("sqs"), None);
    }

    #[test]
    fn infer_service_from_action_known_and_unknown() {
        assert_eq!(infer_service_from_action("GetCallerIdentity").as_deref(), Some("sts"));
        assert_eq!(infer_service_from_action("CreateRole").as_deref(), Some("iam"));
        assert_eq!(infer_service_from_action("SendEmail").as_deref(), Some("ses"));
        assert_eq!(infer_service_from_action("SendMessage"), None);
    }

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("QueueUrl").unwrap(), "http://localhost:4566/queue");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }

    #[test]
    fn parse_query_body_plus_and_bare_key() {
        let params = parse_query_body(&Bytes::from("Action=Publish&Message=hello+world&Flag"));
        assert_eq!(params.get("Message").unwrap(), "hello world");
        assert_eq!(params.get("Flag").unwrap(), "");
    }

    #[test]
    fn detect_amz_target_header() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-amz-target",
            http::HeaderValue::from_static("DynamoDB_20120810.PutItem"),
        );
        let detected = detect_service(&headers, &HashMap::new(), &Bytes::new()).unwrap();
        assert_eq!(detected.service, "dynamodb");
        assert_eq!(detected.action, "PutItem");
        assert_eq!(detected.protocol, AwsProtocol::Json);
    }

    #[test]
    fn detect_query_action_without_auth_infers_service() {
        let query = params(&[("Action", "AssumeRoleWithWebIdentity")]);
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "AssumeRoleWithWebIdentity");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_form_body_action_without_auth_infers_service() {
        let body = Bytes::from("Action=GetCallerIdentity&Version=2011-06-15");
        let detected = detect_service(&HeaderMap::new(), &HashMap::new(), &body).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "GetCallerIdentity");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_sigv4_presigned_url() {
        let query = params(&[(
            "X-Amz-Credential",
            "AKIDEXAMPLE/20240101/us-east-1/s3/aws4_request",
        )]);
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.action, "");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_sigv2_presigned_url_as_s3() {
        let query = params(&[
            ("AWSAccessKeyId", "AKIDEXAMPLE"),
            ("Signature", "sig"),
            ("Expires", "1700000000"),
        ]);
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_unidentified_request_returns_none() {
        assert!(detect_service(&HeaderMap::new(), &HashMap::new(), &Bytes::new()).is_none());
    }
}