Skip to main content

fakecloud_core/
dispatch.rs

1use axum::body::Body;
2use axum::extract::{Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::protocol::{self, AwsProtocol};
9use crate::registry::ServiceRegistry;
10use crate::service::{AwsRequest, ResponseBody};
11
12/// The main dispatch handler. All HTTP requests come through here.
13pub async fn dispatch(
14    Extension(registry): Extension<Arc<ServiceRegistry>>,
15    Extension(config): Extension<Arc<DispatchConfig>>,
16    Query(query_params): Query<HashMap<String, String>>,
17    request: Request<Body>,
18) -> Response<Body> {
19    let request_id = uuid::Uuid::new_v4().to_string();
20
21    let (parts, body) = request.into_parts();
22    // TODO: plumb streaming request bodies end-to-end to remove the cap.
23    // 128 MiB comfortably covers every legitimate single-PutObject (AWS
24    // recommends multipart above ~100 MiB) and each multipart part is
25    // dispatched through here separately, so a 20 GiB upload stays under this
26    // limit per-request.
27    const MAX_BODY_BYTES: usize = 128 * 1024 * 1024;
28    let body_bytes = match axum::body::to_bytes(body, MAX_BODY_BYTES).await {
29        Ok(b) => b,
30        Err(_) => {
31            return build_error_response(
32                StatusCode::PAYLOAD_TOO_LARGE,
33                "RequestEntityTooLarge",
34                "Request body too large",
35                &request_id,
36                AwsProtocol::Query,
37            );
38        }
39    };
40
41    // Detect service and action
42    let detected = match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
43        Some(d) => d,
44        None => {
45            // OPTIONS requests (CORS preflight) don't carry Authorization headers.
46            // Route them to S3 since S3 is the only REST service that handles CORS.
47            // Note: API Gateway CORS preflight is not fully supported in this emulator
48            // because we can't distinguish between S3 and API Gateway OPTIONS requests
49            // without additional context (in real AWS, they have different domains).
50            if parts.method == http::Method::OPTIONS {
51                protocol::DetectedRequest {
52                    service: "s3".to_string(),
53                    action: String::new(),
54                    protocol: AwsProtocol::Rest,
55                }
56            } else if !parts.uri.path().starts_with("/_") {
57                // Requests without AWS auth that don't match any service might be
58                // API Gateway execute API calls (plain HTTP without signatures).
59                // Route them to apigateway service which will validate if a matching
60                // API/stage exists. Skip special FakeCloud endpoints (/_*).
61                protocol::DetectedRequest {
62                    service: "apigateway".to_string(),
63                    action: String::new(),
64                    protocol: AwsProtocol::RestJson,
65                }
66            } else {
67                return build_error_response(
68                    StatusCode::BAD_REQUEST,
69                    "MissingAction",
70                    "Could not determine target service or action from request",
71                    &request_id,
72                    AwsProtocol::Query,
73                );
74            }
75        }
76    };
77
78    // Look up service
79    let service = match registry.get(&detected.service) {
80        Some(s) => s,
81        None => {
82            return build_error_response(
83                detected.protocol.error_status(),
84                "UnknownService",
85                &format!("Service '{}' is not available", detected.service),
86                &request_id,
87                detected.protocol,
88            );
89        }
90    };
91
92    // Extract region and access key from auth header
93    let sigv4_info = fakecloud_aws::sigv4::parse_sigv4(
94        parts
95            .headers
96            .get("authorization")
97            .and_then(|v| v.to_str().ok())
98            .unwrap_or(""),
99    );
100    let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());
101    let region = sigv4_info
102        .map(|info| info.region)
103        .or_else(|| extract_region_from_user_agent(&parts.headers))
104        .unwrap_or_else(|| config.region.clone());
105
106    // Build path segments
107    let path = parts.uri.path().to_string();
108    let raw_query = parts.uri.query().unwrap_or("").to_string();
109    let path_segments: Vec<String> = path
110        .split('/')
111        .filter(|s| !s.is_empty())
112        .map(|s| s.to_string())
113        .collect();
114
115    // For JSON protocol, validate that non-empty bodies are valid JSON
116    if detected.protocol == AwsProtocol::Json
117        && !body_bytes.is_empty()
118        && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
119    {
120        return build_error_response(
121            StatusCode::BAD_REQUEST,
122            "SerializationException",
123            "Start of structure or map found where not expected",
124            &request_id,
125            AwsProtocol::Json,
126        );
127    }
128
129    // Merge query params with form body params for Query protocol
130    let mut all_params = query_params;
131    if detected.protocol == AwsProtocol::Query {
132        let body_params = protocol::parse_query_body(&body_bytes);
133        for (k, v) in body_params {
134            all_params.entry(k).or_insert(v);
135        }
136    }
137
138    let aws_request = AwsRequest {
139        service: detected.service.clone(),
140        action: detected.action.clone(),
141        region,
142        account_id: config.account_id.clone(),
143        request_id: request_id.clone(),
144        headers: parts.headers,
145        query_params: all_params,
146        body: body_bytes,
147        path_segments,
148        raw_path: path,
149        raw_query,
150        method: parts.method,
151        is_query_protocol: detected.protocol == AwsProtocol::Query,
152        access_key_id,
153    };
154
155    tracing::info!(
156        service = %aws_request.service,
157        action = %aws_request.action,
158        request_id = %aws_request.request_id,
159        "handling request"
160    );
161
162    match service.handle(aws_request).await {
163        Ok(resp) => {
164            let mut builder = Response::builder()
165                .status(resp.status)
166                .header("x-amzn-requestid", &request_id)
167                .header("x-amz-request-id", &request_id);
168
169            if !resp.content_type.is_empty() {
170                builder = builder.header("content-type", &resp.content_type);
171            }
172
173            let has_content_length = resp
174                .headers
175                .iter()
176                .any(|(k, _)| k.as_str().eq_ignore_ascii_case("content-length"));
177
178            for (k, v) in &resp.headers {
179                builder = builder.header(k, v);
180            }
181
182            match resp.body {
183                ResponseBody::Bytes(b) => builder.body(Body::from(b)).unwrap(),
184                ResponseBody::File { file, size } => {
185                    let stream = tokio_util::io::ReaderStream::new(file);
186                    let body = Body::from_stream(stream);
187                    if !has_content_length {
188                        builder = builder.header("content-length", size.to_string());
189                    }
190                    builder.body(body).unwrap()
191                }
192            }
193        }
194        Err(err) => {
195            tracing::warn!(
196                service = %detected.service,
197                action = %detected.action,
198                error = %err,
199                "request failed"
200            );
201            let error_headers = err.response_headers().to_vec();
202            let mut resp = build_error_response_with_fields(
203                err.status(),
204                err.code(),
205                &err.message(),
206                &request_id,
207                detected.protocol,
208                err.extra_fields(),
209            );
210            for (k, v) in &error_headers {
211                if let (Ok(name), Ok(val)) = (
212                    k.parse::<http::header::HeaderName>(),
213                    v.parse::<http::header::HeaderValue>(),
214                ) {
215                    resp.headers_mut().insert(name, val);
216                }
217            }
218            resp
219        }
220    }
221}
222
/// Configuration passed to the dispatch handler.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DispatchConfig {
    /// Region used when neither the SigV4 `Authorization` header nor the
    /// User-Agent `region/<r>` token supplies one.
    pub region: String,
    /// Account id stamped onto every dispatched request.
    pub account_id: String,
}
229
230/// Extract region from User-Agent header suffix `region/<region>`.
231fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
232    let ua = headers.get("user-agent")?.to_str().ok()?;
233    for part in ua.split_whitespace() {
234        if let Some(region) = part.strip_prefix("region/") {
235            if !region.is_empty() {
236                return Some(region.to_string());
237            }
238        }
239    }
240    None
241}
242
243fn build_error_response(
244    status: StatusCode,
245    code: &str,
246    message: &str,
247    request_id: &str,
248    protocol: AwsProtocol,
249) -> Response<Body> {
250    build_error_response_with_fields(status, code, message, request_id, protocol, &[])
251}
252
253fn build_error_response_with_fields(
254    status: StatusCode,
255    code: &str,
256    message: &str,
257    request_id: &str,
258    protocol: AwsProtocol,
259    extra_fields: &[(String, String)],
260) -> Response<Body> {
261    let (status, content_type, body) = match protocol {
262        AwsProtocol::Query => {
263            fakecloud_aws::error::xml_error_response(status, code, message, request_id)
264        }
265        AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
266            status,
267            code,
268            message,
269            request_id,
270            extra_fields,
271        ),
272        AwsProtocol::Json | AwsProtocol::RestJson => {
273            fakecloud_aws::error::json_error_response(status, code, message)
274        }
275    };
276
277    Response::builder()
278        .status(status)
279        .header("content-type", content_type)
280        .header("x-amzn-requestid", request_id)
281        .header("x-amz-request-id", request_id)
282        .body(Body::from(body))
283        .unwrap()
284}
285
286trait ProtocolExt {
287    fn error_status(&self) -> StatusCode;
288}
289
290impl ProtocolExt for AwsProtocol {
291    fn error_status(&self) -> StatusCode {
292        StatusCode::BAD_REQUEST
293    }
294}