Skip to main content

fakecloud_core/
dispatch.rs

1use axum::body::Body;
2use axum::extract::{Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::protocol::{self, AwsProtocol};
9use crate::registry::ServiceRegistry;
10use crate::service::AwsRequest;
11
12/// The main dispatch handler. All HTTP requests come through here.
13pub async fn dispatch(
14    Extension(registry): Extension<Arc<ServiceRegistry>>,
15    Extension(config): Extension<Arc<DispatchConfig>>,
16    Query(query_params): Query<HashMap<String, String>>,
17    request: Request<Body>,
18) -> Response<Body> {
19    let request_id = uuid::Uuid::new_v4().to_string();
20
21    let (parts, body) = request.into_parts();
22    let body_bytes = match axum::body::to_bytes(body, 10 * 1024 * 1024).await {
23        Ok(b) => b,
24        Err(_) => {
25            return build_error_response(
26                StatusCode::PAYLOAD_TOO_LARGE,
27                "RequestEntityTooLarge",
28                "Request body too large",
29                &request_id,
30                AwsProtocol::Query,
31            );
32        }
33    };
34
35    // Detect service and action
36    let detected = match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
37        Some(d) => d,
38        None => {
39            // OPTIONS requests (CORS preflight) don't carry Authorization headers.
40            // Route them to S3 since S3 is the only REST service that handles CORS.
41            if parts.method == http::Method::OPTIONS {
42                protocol::DetectedRequest {
43                    service: "s3".to_string(),
44                    action: String::new(),
45                    protocol: AwsProtocol::Rest,
46                }
47            } else {
48                return build_error_response(
49                    StatusCode::BAD_REQUEST,
50                    "MissingAction",
51                    "Could not determine target service or action from request",
52                    &request_id,
53                    AwsProtocol::Query,
54                );
55            }
56        }
57    };
58
59    // Look up service
60    let service = match registry.get(&detected.service) {
61        Some(s) => s,
62        None => {
63            return build_error_response(
64                detected.protocol.error_status(),
65                "UnknownService",
66                &format!("Service '{}' is not available", detected.service),
67                &request_id,
68                detected.protocol,
69            );
70        }
71    };
72
73    // Extract region and access key from auth header
74    let sigv4_info = fakecloud_aws::sigv4::parse_sigv4(
75        parts
76            .headers
77            .get("authorization")
78            .and_then(|v| v.to_str().ok())
79            .unwrap_or(""),
80    );
81    let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());
82    let region = sigv4_info
83        .map(|info| info.region)
84        .or_else(|| extract_region_from_user_agent(&parts.headers))
85        .unwrap_or_else(|| config.region.clone());
86
87    // Build path segments
88    let path = parts.uri.path().to_string();
89    let raw_query = parts.uri.query().unwrap_or("").to_string();
90    let path_segments: Vec<String> = path
91        .split('/')
92        .filter(|s| !s.is_empty())
93        .map(|s| s.to_string())
94        .collect();
95
96    // For JSON protocol, validate that non-empty bodies are valid JSON
97    if detected.protocol == AwsProtocol::Json
98        && !body_bytes.is_empty()
99        && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
100    {
101        return build_error_response(
102            StatusCode::BAD_REQUEST,
103            "SerializationException",
104            "Start of structure or map found where not expected",
105            &request_id,
106            AwsProtocol::Json,
107        );
108    }
109
110    // Merge query params with form body params for Query protocol
111    let mut all_params = query_params;
112    if detected.protocol == AwsProtocol::Query {
113        let body_params = protocol::parse_query_body(&body_bytes);
114        for (k, v) in body_params {
115            all_params.entry(k).or_insert(v);
116        }
117    }
118
119    let aws_request = AwsRequest {
120        service: detected.service.clone(),
121        action: detected.action.clone(),
122        region,
123        account_id: config.account_id.clone(),
124        request_id: request_id.clone(),
125        headers: parts.headers,
126        query_params: all_params,
127        body: body_bytes,
128        path_segments,
129        raw_path: path,
130        raw_query,
131        method: parts.method,
132        is_query_protocol: detected.protocol == AwsProtocol::Query,
133        access_key_id,
134    };
135
136    tracing::info!(
137        service = %aws_request.service,
138        action = %aws_request.action,
139        request_id = %aws_request.request_id,
140        "handling request"
141    );
142
143    match service.handle(aws_request).await {
144        Ok(resp) => {
145            let mut builder = Response::builder()
146                .status(resp.status)
147                .header("x-amzn-requestid", &request_id)
148                .header("x-amz-request-id", &request_id);
149
150            if !resp.content_type.is_empty() {
151                builder = builder.header("content-type", &resp.content_type);
152            }
153
154            for (k, v) in &resp.headers {
155                builder = builder.header(k, v);
156            }
157
158            builder.body(Body::from(resp.body)).unwrap()
159        }
160        Err(err) => {
161            tracing::warn!(
162                service = %detected.service,
163                action = %detected.action,
164                error = %err,
165                "request failed"
166            );
167            let error_headers = err.response_headers().to_vec();
168            let mut resp = build_error_response_with_fields(
169                err.status(),
170                err.code(),
171                &err.message(),
172                &request_id,
173                detected.protocol,
174                err.extra_fields(),
175            );
176            for (k, v) in &error_headers {
177                if let (Ok(name), Ok(val)) = (
178                    k.parse::<http::header::HeaderName>(),
179                    v.parse::<http::header::HeaderValue>(),
180                ) {
181                    resp.headers_mut().insert(name, val);
182                }
183            }
184            resp
185        }
186    }
187}
188
/// Settings shared by every request the dispatch handler serves.
///
/// The handler receives it as an axum `Extension<Arc<DispatchConfig>>`.
#[derive(Clone, Debug)]
pub struct DispatchConfig {
    /// Fallback region, used when neither the SigV4 `Authorization` header nor
    /// the `User-Agent` header supplies one.
    pub region: String,
    /// Account id copied onto every request built by the dispatcher.
    pub account_id: String,
}
195
196/// Extract region from User-Agent header suffix `region/<region>`.
197fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
198    let ua = headers.get("user-agent")?.to_str().ok()?;
199    for part in ua.split_whitespace() {
200        if let Some(region) = part.strip_prefix("region/") {
201            if !region.is_empty() {
202                return Some(region.to_string());
203            }
204        }
205    }
206    None
207}
208
209fn build_error_response(
210    status: StatusCode,
211    code: &str,
212    message: &str,
213    request_id: &str,
214    protocol: AwsProtocol,
215) -> Response<Body> {
216    build_error_response_with_fields(status, code, message, request_id, protocol, &[])
217}
218
219fn build_error_response_with_fields(
220    status: StatusCode,
221    code: &str,
222    message: &str,
223    request_id: &str,
224    protocol: AwsProtocol,
225    extra_fields: &[(String, String)],
226) -> Response<Body> {
227    let (status, content_type, body) = match protocol {
228        AwsProtocol::Query => {
229            fakecloud_aws::error::xml_error_response(status, code, message, request_id)
230        }
231        AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
232            status,
233            code,
234            message,
235            request_id,
236            extra_fields,
237        ),
238        AwsProtocol::Json | AwsProtocol::RestJson => {
239            fakecloud_aws::error::json_error_response(status, code, message)
240        }
241    };
242
243    Response::builder()
244        .status(status)
245        .header("content-type", content_type)
246        .header("x-amzn-requestid", request_id)
247        .header("x-amz-request-id", request_id)
248        .body(Body::from(body))
249        .unwrap()
250}
251
252trait ProtocolExt {
253    fn error_status(&self) -> StatusCode;
254}
255
256impl ProtocolExt for AwsProtocol {
257    fn error_status(&self) -> StatusCode {
258        StatusCode::BAD_REQUEST
259    }
260}