Skip to main content

fakecloud_core/
protocol.rs

1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Wire protocol that an incoming AWS API request was recognised as.
///
/// The protocol is chosen by `detect_service` in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AwsProtocol {
    /// Query protocol: parameters are form-encoded (query string or body), the
    /// operation is named by the `Action` parameter, and responses are XML.
    /// Typical services: SQS, SNS, IAM, STS, SES.
    Query,
    /// JSON protocol: JSON body, operation named by the `X-Amz-Target` header
    /// (`<Prefix>.<Operation>`), JSON response. Typical services: SSM,
    /// EventBridge, DynamoDB, Secrets Manager, KMS, CloudWatch Logs. SQS and SNS
    /// requests that carry an `X-Amz-Target` header are routed here as well.
    Json,
    /// REST protocol: the operation is derived from the HTTP method and path,
    /// and responses are XML. Only services listed in `REST_SERVICES`
    /// (currently S3) are detected as REST.
    Rest,
}
18
/// Service names (as they appear in a SigV4 credential scope) that are routed
/// over the REST protocol. The scope may come from the `Authorization` header
/// or from a presigned URL's `X-Amz-Credential` parameter.
const REST_SERVICES: &[&str] = &["s3"];
21
22/// Detected service name and action from an incoming HTTP request.
23#[derive(Debug)]
24pub struct DetectedRequest {
25    pub service: String,
26    pub action: String,
27    pub protocol: AwsProtocol,
28}
29
30/// Detect the target service and action from HTTP request components.
31pub fn detect_service(
32    headers: &HeaderMap,
33    query_params: &HashMap<String, String>,
34    body: &Bytes,
35) -> Option<DetectedRequest> {
36    // 1. Check X-Amz-Target header (JSON protocol)
37    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
38        return parse_amz_target(target);
39    }
40
41    // 2. Check for Query protocol (Action parameter in query string or form body)
42    if let Some(action) = query_params.get("Action") {
43        let service =
44            extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
45        if let Some(service) = service {
46            return Some(DetectedRequest {
47                service,
48                action: action.clone(),
49                protocol: AwsProtocol::Query,
50            });
51        }
52    }
53
54    // 3. Try form-encoded body
55    {
56        let form_params = decode_form_urlencoded(body);
57
58        if let Some(action) = form_params.get("Action") {
59            let service =
60                extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
61            if let Some(service) = service {
62                return Some(DetectedRequest {
63                    service,
64                    action: action.clone(),
65                    protocol: AwsProtocol::Query,
66                });
67            }
68        }
69    }
70
71    // 4. Fallback: check auth header for REST-style services (S3, etc.)
72    if let Some(service) = extract_service_from_auth(headers) {
73        if REST_SERVICES.contains(&service.as_str()) {
74            return Some(DetectedRequest {
75                service,
76                action: String::new(), // REST services determine action from method+path
77                protocol: AwsProtocol::Rest,
78            });
79        }
80    }
81
82    // 5. Check query params for presigned URL auth (X-Amz-Credential for SigV4)
83    if let Some(credential) = query_params.get("X-Amz-Credential") {
84        // Format: AKID/date/region/service/aws4_request
85        let parts: Vec<&str> = credential.split('/').collect();
86        if parts.len() >= 4 {
87            let service = parts[3].to_string();
88            if REST_SERVICES.contains(&service.as_str()) {
89                return Some(DetectedRequest {
90                    service,
91                    action: String::new(),
92                    protocol: AwsProtocol::Rest,
93                });
94            }
95        }
96    }
97
98    // 6. Check for SigV2-style presigned URL (AWSAccessKeyId + Signature)
99    if query_params.contains_key("AWSAccessKeyId") && query_params.contains_key("Signature") {
100        // SigV2 presigned URLs are typically for S3
101        return Some(DetectedRequest {
102            service: "s3".to_string(),
103            action: String::new(),
104            protocol: AwsProtocol::Rest,
105        });
106    }
107
108    None
109}
110
111/// Parse `X-Amz-Target: AWSEvents.PutEvents` -> service=events, action=PutEvents
112/// Parse `X-Amz-Target: AmazonSSM.GetParameter` -> service=ssm, action=GetParameter
113fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
114    let (prefix, action) = target.rsplit_once('.')?;
115
116    let service = match prefix {
117        "AWSEvents" => "events",
118        "AmazonSSM" => "ssm",
119        "AmazonSQS" => "sqs",
120        "AmazonSNS" => "sns",
121        "DynamoDB_20120810" => "dynamodb",
122        "Logs_20140328" => "logs",
123        s if s.starts_with("secretsmanager") => "secretsmanager",
124        s if s.starts_with("TrentService") => "kms",
125        _ => return None,
126    };
127
128    Some(DetectedRequest {
129        service: service.to_string(),
130        action: action.to_string(),
131        protocol: AwsProtocol::Json,
132    })
133}
134
/// Guess the owning service from a Query-protocol action name.
///
/// This is the fallback used when the request carries no SigV4
/// `Authorization` header to read the service from. Some operations, such as
/// `AssumeRoleWithSAML` and `AssumeRoleWithWebIdentity`, can be called
/// without authentication. Matching is exact and case-sensitive; unknown
/// actions yield `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];

    let service = if STS_ACTIONS.contains(&action) {
        "sts"
    } else if IAM_ACTIONS.contains(&action) {
        "iam"
    } else {
        return None;
    };
    Some(service.to_owned())
}
155
156/// Extract service name from the SigV4 Authorization header credential scope.
157fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
158    let auth = headers.get("authorization")?.to_str().ok()?;
159    let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
160    Some(info.service)
161}
162
163/// Parse form-encoded body into key-value pairs.
164pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
165    decode_form_urlencoded(body)
166}
167
168fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
169    let s = std::str::from_utf8(input).unwrap_or("");
170    let mut result = HashMap::new();
171    for pair in s.split('&') {
172        if pair.is_empty() {
173            continue;
174        }
175        let (key, value) = match pair.find('=') {
176            Some(pos) => (&pair[..pos], &pair[pos + 1..]),
177            None => (pair, ""),
178        };
179        result.insert(url_decode(key), url_decode(value));
180    }
181    result
182}
183
/// Decode one `application/x-www-form-urlencoded` component.
///
/// `+` becomes a space and `%XX` becomes the byte `0xXX`. The resulting
/// *bytes* are then interpreted as UTF-8, so a multi-byte escape such as
/// `%C3%A9` yields `é`, and raw non-ASCII characters pass through intact.
/// Invalid UTF-8 is replaced with U+FFFD. A `%` that is not followed by two
/// hex digits is kept literally rather than dropped, matching the WHATWG URL
/// standard's form-urlencoded parser.
fn url_decode(input: &str) -> String {
    let mut decoded = Vec::with_capacity(input.len());
    let mut rest = input.as_bytes();
    while let Some((&byte, tail)) = rest.split_first() {
        rest = tail;
        match byte {
            b'+' => decoded.push(b' '),
            b'%' => {
                // A valid escape consumes the two hex digits that follow it.
                if let [hi, lo, after @ ..] = rest {
                    if let (Some(hi), Some(lo)) = (from_hex(*hi), from_hex(*lo)) {
                        decoded.push((hi << 4) | lo);
                        rest = after;
                        continue;
                    }
                }
                // Malformed or truncated escape: keep the `%` as-is.
                decoded.push(b'%');
            }
            other => decoded.push(other),
        }
    }
    // Decoding must happen on bytes: pushing each byte as a `char` would turn
    // UTF-8 sequences into Latin-1 mojibake.
    String::from_utf8_lossy(&decoded).into_owned()
}

/// Value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
211
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_prefix_matched_services() {
        let secrets = parse_amz_target("secretsmanager.GetSecretValue").unwrap();
        assert_eq!(secrets.service, "secretsmanager");
        assert_eq!(secrets.action, "GetSecretValue");

        let kms = parse_amz_target("TrentService.Encrypt").unwrap();
        assert_eq!(kms.service, "kms");
        assert_eq!(kms.action, "Encrypt");
    }

    #[test]
    fn parse_amz_target_rejects_unknown_or_malformed() {
        assert!(parse_amz_target("UnknownService.DoThing").is_none());
        assert!(parse_amz_target("NoSeparator").is_none());
    }

    #[test]
    fn infer_service_from_action_known_and_unknown() {
        assert_eq!(
            infer_service_from_action("AssumeRoleWithWebIdentity").as_deref(),
            Some("sts")
        );
        assert_eq!(infer_service_from_action("CreateRole").as_deref(), Some("iam"));
        assert_eq!(infer_service_from_action("SendMessage"), None);
    }

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("QueueUrl").unwrap(), "http://localhost:4566/queue");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }

    #[test]
    fn parse_query_body_plus_and_bare_keys() {
        let params = parse_query_body(&Bytes::from("a=b+c&flag&&x=%41"));
        assert_eq!(params.len(), 3);
        assert_eq!(params.get("a").map(String::as_str), Some("b c"));
        assert_eq!(params.get("flag").map(String::as_str), Some(""));
        assert_eq!(params.get("x").map(String::as_str), Some("A"));
    }

    #[test]
    fn detect_service_json_target_header() {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", HeaderValue::from_static("AmazonSSM.GetParameter"));
        let detected = detect_service(&headers, &HashMap::new(), &Bytes::new()).unwrap();
        assert_eq!(detected.service, "ssm");
        assert_eq!(detected.action, "GetParameter");
        assert_eq!(detected.protocol, AwsProtocol::Json);
    }

    #[test]
    fn detect_service_query_string_without_auth_infers_sts() {
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "GetCallerIdentity".to_string());
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "GetCallerIdentity");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_form_body_without_auth_infers_iam() {
        let body = Bytes::from_static(b"Action=ListRoles&Version=2010-05-08");
        let detected = detect_service(&HeaderMap::new(), &HashMap::new(), &body).unwrap();
        assert_eq!(detected.service, "iam");
        assert_eq!(detected.action, "ListRoles");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_sigv2_presigned_is_s3() {
        let mut query = HashMap::new();
        query.insert("AWSAccessKeyId".to_string(), "AKID".to_string());
        query.insert("Signature".to_string(), "sig".to_string());
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.action, "");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_unrecognised_request_is_none() {
        assert!(detect_service(&HeaderMap::new(), &HashMap::new(), &Bytes::new()).is_none());

        // Unknown action with no auth header cannot be attributed to a service.
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "SendMessage".to_string());
        assert!(detect_service(&HeaderMap::new(), &query, &Bytes::new()).is_none());
    }
}