Skip to main content

fakecloud_core/
dispatch.rs

1use axum::body::Body;
2use axum::extract::{Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::protocol::{self, AwsProtocol};
9use crate::registry::ServiceRegistry;
10use crate::service::AwsRequest;
11
12/// The main dispatch handler. All HTTP requests come through here.
13pub async fn dispatch(
14    Extension(registry): Extension<Arc<ServiceRegistry>>,
15    Extension(config): Extension<Arc<DispatchConfig>>,
16    Query(query_params): Query<HashMap<String, String>>,
17    request: Request<Body>,
18) -> Response<Body> {
19    let request_id = uuid::Uuid::new_v4().to_string();
20
21    let (parts, body) = request.into_parts();
22    let body_bytes = match axum::body::to_bytes(body, 10 * 1024 * 1024).await {
23        Ok(b) => b,
24        Err(_) => {
25            return build_error_response(
26                StatusCode::PAYLOAD_TOO_LARGE,
27                "RequestEntityTooLarge",
28                "Request body too large",
29                &request_id,
30                AwsProtocol::Query,
31            );
32        }
33    };
34
35    // Detect service and action
36    let detected = match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
37        Some(d) => d,
38        None => {
39            // OPTIONS requests (CORS preflight) don't carry Authorization headers.
40            // Route them to S3 since S3 is the only REST service that handles CORS.
41            // Note: API Gateway CORS preflight is not fully supported in this emulator
42            // because we can't distinguish between S3 and API Gateway OPTIONS requests
43            // without additional context (in real AWS, they have different domains).
44            if parts.method == http::Method::OPTIONS {
45                protocol::DetectedRequest {
46                    service: "s3".to_string(),
47                    action: String::new(),
48                    protocol: AwsProtocol::Rest,
49                }
50            } else if !parts.uri.path().starts_with("/_") {
51                // Requests without AWS auth that don't match any service might be
52                // API Gateway execute API calls (plain HTTP without signatures).
53                // Route them to apigateway service which will validate if a matching
54                // API/stage exists. Skip special FakeCloud endpoints (/_*).
55                protocol::DetectedRequest {
56                    service: "apigateway".to_string(),
57                    action: String::new(),
58                    protocol: AwsProtocol::RestJson,
59                }
60            } else {
61                return build_error_response(
62                    StatusCode::BAD_REQUEST,
63                    "MissingAction",
64                    "Could not determine target service or action from request",
65                    &request_id,
66                    AwsProtocol::Query,
67                );
68            }
69        }
70    };
71
72    // Look up service
73    let service = match registry.get(&detected.service) {
74        Some(s) => s,
75        None => {
76            return build_error_response(
77                detected.protocol.error_status(),
78                "UnknownService",
79                &format!("Service '{}' is not available", detected.service),
80                &request_id,
81                detected.protocol,
82            );
83        }
84    };
85
86    // Extract region and access key from auth header
87    let sigv4_info = fakecloud_aws::sigv4::parse_sigv4(
88        parts
89            .headers
90            .get("authorization")
91            .and_then(|v| v.to_str().ok())
92            .unwrap_or(""),
93    );
94    let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());
95    let region = sigv4_info
96        .map(|info| info.region)
97        .or_else(|| extract_region_from_user_agent(&parts.headers))
98        .unwrap_or_else(|| config.region.clone());
99
100    // Build path segments
101    let path = parts.uri.path().to_string();
102    let raw_query = parts.uri.query().unwrap_or("").to_string();
103    let path_segments: Vec<String> = path
104        .split('/')
105        .filter(|s| !s.is_empty())
106        .map(|s| s.to_string())
107        .collect();
108
109    // For JSON protocol, validate that non-empty bodies are valid JSON
110    if detected.protocol == AwsProtocol::Json
111        && !body_bytes.is_empty()
112        && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
113    {
114        return build_error_response(
115            StatusCode::BAD_REQUEST,
116            "SerializationException",
117            "Start of structure or map found where not expected",
118            &request_id,
119            AwsProtocol::Json,
120        );
121    }
122
123    // Merge query params with form body params for Query protocol
124    let mut all_params = query_params;
125    if detected.protocol == AwsProtocol::Query {
126        let body_params = protocol::parse_query_body(&body_bytes);
127        for (k, v) in body_params {
128            all_params.entry(k).or_insert(v);
129        }
130    }
131
132    let aws_request = AwsRequest {
133        service: detected.service.clone(),
134        action: detected.action.clone(),
135        region,
136        account_id: config.account_id.clone(),
137        request_id: request_id.clone(),
138        headers: parts.headers,
139        query_params: all_params,
140        body: body_bytes,
141        path_segments,
142        raw_path: path,
143        raw_query,
144        method: parts.method,
145        is_query_protocol: detected.protocol == AwsProtocol::Query,
146        access_key_id,
147    };
148
149    tracing::info!(
150        service = %aws_request.service,
151        action = %aws_request.action,
152        request_id = %aws_request.request_id,
153        "handling request"
154    );
155
156    match service.handle(aws_request).await {
157        Ok(resp) => {
158            let mut builder = Response::builder()
159                .status(resp.status)
160                .header("x-amzn-requestid", &request_id)
161                .header("x-amz-request-id", &request_id);
162
163            if !resp.content_type.is_empty() {
164                builder = builder.header("content-type", &resp.content_type);
165            }
166
167            for (k, v) in &resp.headers {
168                builder = builder.header(k, v);
169            }
170
171            builder.body(Body::from(resp.body)).unwrap()
172        }
173        Err(err) => {
174            tracing::warn!(
175                service = %detected.service,
176                action = %detected.action,
177                error = %err,
178                "request failed"
179            );
180            let error_headers = err.response_headers().to_vec();
181            let mut resp = build_error_response_with_fields(
182                err.status(),
183                err.code(),
184                &err.message(),
185                &request_id,
186                detected.protocol,
187                err.extra_fields(),
188            );
189            for (k, v) in &error_headers {
190                if let (Ok(name), Ok(val)) = (
191                    k.parse::<http::header::HeaderName>(),
192                    v.parse::<http::header::HeaderValue>(),
193                ) {
194                    resp.headers_mut().insert(name, val);
195                }
196            }
197            resp
198        }
199    }
200}
201
/// Server-wide settings shared with every request handled by [`dispatch`].
#[derive(Clone, Debug)]
pub struct DispatchConfig {
    /// Region used when neither the SigV4 `Authorization` header nor the
    /// User-Agent `region/<name>` token supplies one.
    pub region: String,
    /// Account ID attached to every request passed to a service.
    pub account_id: String,
}
208
209/// Extract region from User-Agent header suffix `region/<region>`.
210fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
211    let ua = headers.get("user-agent")?.to_str().ok()?;
212    for part in ua.split_whitespace() {
213        if let Some(region) = part.strip_prefix("region/") {
214            if !region.is_empty() {
215                return Some(region.to_string());
216            }
217        }
218    }
219    None
220}
221
222fn build_error_response(
223    status: StatusCode,
224    code: &str,
225    message: &str,
226    request_id: &str,
227    protocol: AwsProtocol,
228) -> Response<Body> {
229    build_error_response_with_fields(status, code, message, request_id, protocol, &[])
230}
231
232fn build_error_response_with_fields(
233    status: StatusCode,
234    code: &str,
235    message: &str,
236    request_id: &str,
237    protocol: AwsProtocol,
238    extra_fields: &[(String, String)],
239) -> Response<Body> {
240    let (status, content_type, body) = match protocol {
241        AwsProtocol::Query => {
242            fakecloud_aws::error::xml_error_response(status, code, message, request_id)
243        }
244        AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
245            status,
246            code,
247            message,
248            request_id,
249            extra_fields,
250        ),
251        AwsProtocol::Json | AwsProtocol::RestJson => {
252            fakecloud_aws::error::json_error_response(status, code, message)
253        }
254    };
255
256    Response::builder()
257        .status(status)
258        .header("content-type", content_type)
259        .header("x-amzn-requestid", request_id)
260        .header("x-amz-request-id", request_id)
261        .body(Body::from(body))
262        .unwrap()
263}
264
265trait ProtocolExt {
266    fn error_status(&self) -> StatusCode;
267}
268
269impl ProtocolExt for AwsProtocol {
270    fn error_status(&self) -> StatusCode {
271        StatusCode::BAD_REQUEST
272    }
273}