// fakecloud_core/protocol.rs

1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Wire protocol spoken by an AWS service endpoint.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AwsProtocol {
    /// Query protocol: form-encoded parameters carrying an `Action`, answered
    /// with XML. Used by SQS, SNS, IAM, STS and SES.
    Query,
    /// JSON protocol: JSON request body routed by the `X-Amz-Target` header,
    /// answered with JSON. Used by SSM, EventBridge, DynamoDB, Secrets Manager,
    /// KMS and CloudWatch Logs.
    Json,
    /// REST protocol: routed by HTTP method and path rather than an action
    /// name. Used by S3, API Gateway and Route 53 (XML payloads) and by Lambda
    /// (JSON payloads).
    Rest,
}
18
/// SigV4 credential-scope service names whose requests are routed REST-style
/// (HTTP method + path) rather than by an `Action` parameter.
///
/// Lambda belongs here because it is routed by method and path and does not
/// send an `X-Amz-Target` header.
const REST_SERVICES: &[&str] = &[
    "s3",
    "lambda",
];
22
23/// Detected service name and action from an incoming HTTP request.
24#[derive(Debug)]
25pub struct DetectedRequest {
26    pub service: String,
27    pub action: String,
28    pub protocol: AwsProtocol,
29}
30
31/// Detect the target service and action from HTTP request components.
32pub fn detect_service(
33    headers: &HeaderMap,
34    query_params: &HashMap<String, String>,
35    body: &Bytes,
36) -> Option<DetectedRequest> {
37    // 1. Check X-Amz-Target header (JSON protocol)
38    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
39        return parse_amz_target(target);
40    }
41
42    // 2. Check for Query protocol (Action parameter in query string or form body)
43    if let Some(action) = query_params.get("Action") {
44        let service =
45            extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
46        if let Some(service) = service {
47            return Some(DetectedRequest {
48                service,
49                action: action.clone(),
50                protocol: AwsProtocol::Query,
51            });
52        }
53    }
54
55    // 3. Try form-encoded body
56    {
57        let form_params = decode_form_urlencoded(body);
58
59        if let Some(action) = form_params.get("Action") {
60            let service =
61                extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
62            if let Some(service) = service {
63                return Some(DetectedRequest {
64                    service,
65                    action: action.clone(),
66                    protocol: AwsProtocol::Query,
67                });
68            }
69        }
70    }
71
72    // 4. Fallback: check auth header for REST-style services (S3, etc.)
73    if let Some(service) = extract_service_from_auth(headers) {
74        if REST_SERVICES.contains(&service.as_str()) {
75            return Some(DetectedRequest {
76                service,
77                action: String::new(), // REST services determine action from method+path
78                protocol: AwsProtocol::Rest,
79            });
80        }
81    }
82
83    // 5. Check query params for presigned URL auth (X-Amz-Credential for SigV4)
84    if let Some(credential) = query_params.get("X-Amz-Credential") {
85        // Format: AKID/date/region/service/aws4_request
86        let parts: Vec<&str> = credential.split('/').collect();
87        if parts.len() >= 4 {
88            let service = parts[3].to_string();
89            if REST_SERVICES.contains(&service.as_str()) {
90                return Some(DetectedRequest {
91                    service,
92                    action: String::new(),
93                    protocol: AwsProtocol::Rest,
94                });
95            }
96        }
97    }
98
99    // 6. Check for SigV2-style presigned URL (AWSAccessKeyId + Signature + Expires)
100    //    Only match when all three SigV2 presigned-URL parameters are present so
101    //    we don't accidentally claim non-S3 requests.
102    if query_params.contains_key("AWSAccessKeyId")
103        && query_params.contains_key("Signature")
104        && query_params.contains_key("Expires")
105    {
106        return Some(DetectedRequest {
107            service: "s3".to_string(),
108            action: String::new(),
109            protocol: AwsProtocol::Rest,
110        });
111    }
112
113    None
114}
115
116/// Parse `X-Amz-Target: AWSEvents.PutEvents` -> service=events, action=PutEvents
117/// Parse `X-Amz-Target: AmazonSSM.GetParameter` -> service=ssm, action=GetParameter
118fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
119    let (prefix, action) = target.rsplit_once('.')?;
120
121    let service = match prefix {
122        "AWSEvents" => "events",
123        "AmazonSSM" => "ssm",
124        "AmazonSQS" => "sqs",
125        "AmazonSNS" => "sns",
126        "DynamoDB_20120810" => "dynamodb",
127        "Logs_20140328" => "logs",
128        s if s.starts_with("secretsmanager") => "secretsmanager",
129        s if s.starts_with("TrentService") => "kms",
130        _ => return None,
131    };
132
133    Some(DetectedRequest {
134        service: service.to_string(),
135        action: action.to_string(),
136        protocol: AwsProtocol::Json,
137    })
138}
139
/// Guess the owning service of a Query-protocol `Action` when no SigV4
/// credential scope is available to identify it.
///
/// Some operations (e.g. `AssumeRoleWithSAML`, `AssumeRoleWithWebIdentity`)
/// can be called without authentication, so they arrive without an
/// `Authorization` header.
///
/// Returns `Some("sts")` or `Some("iam")` for the actions listed below and
/// `None` for anything else. Matching is exact and case-sensitive.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];

    let service = if STS_ACTIONS.contains(&action) {
        "sts"
    } else if IAM_ACTIONS.contains(&action) {
        "iam"
    } else {
        return None;
    };
    Some(service.to_owned())
}
160
161/// Extract service name from the SigV4 Authorization header credential scope.
162fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
163    let auth = headers.get("authorization")?.to_str().ok()?;
164    let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
165    Some(info.service)
166}
167
168/// Parse form-encoded body into key-value pairs.
169pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
170    decode_form_urlencoded(body)
171}
172
173fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
174    let s = std::str::from_utf8(input).unwrap_or("");
175    let mut result = HashMap::new();
176    for pair in s.split('&') {
177        if pair.is_empty() {
178            continue;
179        }
180        let (key, value) = match pair.find('=') {
181            Some(pos) => (&pair[..pos], &pair[pos + 1..]),
182            None => (pair, ""),
183        };
184        result.insert(url_decode(key), url_decode(value));
185    }
186    result
187}
188
189fn url_decode(input: &str) -> String {
190    let mut result = String::with_capacity(input.len());
191    let mut bytes = input.bytes();
192    while let Some(b) = bytes.next() {
193        match b {
194            b'+' => result.push(' '),
195            b'%' => {
196                let high = bytes.next().and_then(from_hex);
197                let low = bytes.next().and_then(from_hex);
198                if let (Some(h), Some(l)) = (high, low) {
199                    result.push((h << 4 | l) as char);
200                }
201            }
202            _ => result.push(b as char),
203        }
204    }
205    result
206}
207
/// Convert one ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`) to its value.
///
/// Returns `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    // `to_digit` accepts only ASCII digits and letters, so non-ASCII bytes give
    // `None`. A base-16 digit is always < 16, so narrowing to u8 cannot truncate.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
216
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    /// Build an owned query-parameter map from borrowed pairs.
    fn query(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_other_known_prefixes() {
        let cases = [
            ("DynamoDB_20120810.PutItem", "dynamodb", "PutItem"),
            ("Logs_20140328.PutLogEvents", "logs", "PutLogEvents"),
            ("secretsmanager.GetSecretValue", "secretsmanager", "GetSecretValue"),
            ("TrentService.Encrypt", "kms", "Encrypt"),
            ("AmazonSQS.SendMessage", "sqs", "SendMessage"),
        ];
        for &(target, service, action) in cases.iter() {
            let result = parse_amz_target(target).unwrap();
            assert_eq!(result.service, service, "target {}", target);
            assert_eq!(result.action, action, "target {}", target);
            assert_eq!(result.protocol, AwsProtocol::Json);
        }
    }

    #[test]
    fn parse_amz_target_rejects_unknown_or_malformed() {
        assert!(parse_amz_target("UnknownService.DoThing").is_none());
        assert!(parse_amz_target("NoDotHere").is_none());
        assert!(parse_amz_target("").is_none());
    }

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
        assert_eq!(params.get("QueueUrl").unwrap(), "http://localhost:4566/queue");
    }

    #[test]
    fn parse_query_body_edge_cases() {
        let params = parse_query_body(&Bytes::from("Name=a+b&Flag&&Empty="));
        assert_eq!(params.get("Name").map(String::as_str), Some("a b"));
        assert_eq!(params.get("Flag").map(String::as_str), Some(""));
        assert_eq!(params.get("Empty").map(String::as_str), Some(""));
        assert_eq!(params.len(), 3);

        assert!(parse_query_body(&Bytes::new()).is_empty());
        assert!(parse_query_body(&Bytes::from_static(&[0xff, 0xfe])).is_empty());
    }

    #[test]
    fn infer_service_from_unsigned_actions() {
        assert_eq!(
            infer_service_from_action("AssumeRoleWithWebIdentity").as_deref(),
            Some("sts")
        );
        assert_eq!(infer_service_from_action("ListRoles").as_deref(), Some("iam"));
        assert_eq!(infer_service_from_action("SendMessage"), None);
    }

    #[test]
    fn detect_service_from_amz_target() {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", HeaderValue::from_static("AmazonSSM.GetParameter"));
        let detected = detect_service(&headers, &HashMap::new(), &Bytes::new()).unwrap();
        assert_eq!(detected.service, "ssm");
        assert_eq!(detected.action, "GetParameter");
        assert_eq!(detected.protocol, AwsProtocol::Json);
    }

    #[test]
    fn detect_service_unsigned_query_and_form_actions() {
        let params = query(&[("Action", "GetCallerIdentity")]);
        let detected = detect_service(&HeaderMap::new(), &params, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "GetCallerIdentity");
        assert_eq!(detected.protocol, AwsProtocol::Query);

        let body = Bytes::from("Action=AssumeRoleWithWebIdentity&RoleArn=arn");
        let detected = detect_service(&HeaderMap::new(), &HashMap::new(), &body).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "AssumeRoleWithWebIdentity");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_presigned_urls() {
        let sigv4 = query(&[("X-Amz-Credential", "AKID/20240101/us-east-1/s3/aws4_request")]);
        let detected = detect_service(&HeaderMap::new(), &sigv4, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.action, "");
        assert_eq!(detected.protocol, AwsProtocol::Rest);

        let sigv2 = query(&[("AWSAccessKeyId", "AKID"), ("Signature", "sig"), ("Expires", "0")]);
        let detected = detect_service(&HeaderMap::new(), &sigv2, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);

        let partial = query(&[("AWSAccessKeyId", "AKID"), ("Signature", "sig")]);
        assert!(detect_service(&HeaderMap::new(), &partial, &Bytes::new()).is_none());
    }

    #[test]
    fn detect_service_unknown_request() {
        assert!(detect_service(&HeaderMap::new(), &HashMap::new(), &Bytes::new()).is_none());
        let params = query(&[("Action", "SomeUnknownAction")]);
        assert!(detect_service(&HeaderMap::new(), &params, &Bytes::new()).is_none());
    }
}