Skip to main content

fakecloud_core/
dispatch.rs

1use axum::body::Body;
2use axum::extract::{Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::protocol::{self, AwsProtocol};
9use crate::registry::ServiceRegistry;
10use crate::service::AwsRequest;
11
12/// The main dispatch handler. All HTTP requests come through here.
13pub async fn dispatch(
14    Extension(registry): Extension<Arc<ServiceRegistry>>,
15    Extension(config): Extension<Arc<DispatchConfig>>,
16    Query(query_params): Query<HashMap<String, String>>,
17    request: Request<Body>,
18) -> Response<Body> {
19    let request_id = uuid::Uuid::new_v4().to_string();
20
21    let (parts, body) = request.into_parts();
22    let body_bytes = match axum::body::to_bytes(body, 10 * 1024 * 1024).await {
23        Ok(b) => b,
24        Err(_) => {
25            return build_error_response(
26                StatusCode::PAYLOAD_TOO_LARGE,
27                "RequestEntityTooLarge",
28                "Request body too large",
29                &request_id,
30                AwsProtocol::Query,
31            );
32        }
33    };
34
35    // Detect service and action
36    let detected = match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
37        Some(d) => d,
38        None => {
39            return build_error_response(
40                StatusCode::BAD_REQUEST,
41                "MissingAction",
42                "Could not determine target service or action from request",
43                &request_id,
44                AwsProtocol::Query,
45            );
46        }
47    };
48
49    // Look up service
50    let service = match registry.get(&detected.service) {
51        Some(s) => s,
52        None => {
53            return build_error_response(
54                detected.protocol.error_status(),
55                "UnknownService",
56                &format!("Service '{}' is not available", detected.service),
57                &request_id,
58                detected.protocol,
59            );
60        }
61    };
62
63    // Extract region and access key from auth header
64    let sigv4_info = fakecloud_aws::sigv4::parse_sigv4(
65        parts
66            .headers
67            .get("authorization")
68            .and_then(|v| v.to_str().ok())
69            .unwrap_or(""),
70    );
71    let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());
72    let region = sigv4_info
73        .map(|info| info.region)
74        .or_else(|| extract_region_from_user_agent(&parts.headers))
75        .unwrap_or_else(|| config.region.clone());
76
77    // Build path segments
78    let path = parts.uri.path().to_string();
79    let path_segments: Vec<String> = path
80        .split('/')
81        .filter(|s| !s.is_empty())
82        .map(|s| s.to_string())
83        .collect();
84
85    // For JSON protocol, validate that non-empty bodies are valid JSON
86    if detected.protocol == AwsProtocol::Json
87        && !body_bytes.is_empty()
88        && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
89    {
90        return build_error_response(
91            StatusCode::BAD_REQUEST,
92            "SerializationException",
93            "Start of structure or map found where not expected",
94            &request_id,
95            AwsProtocol::Json,
96        );
97    }
98
99    // Merge query params with form body params for Query protocol
100    let mut all_params = query_params;
101    if detected.protocol == AwsProtocol::Query {
102        let body_params = protocol::parse_query_body(&body_bytes);
103        for (k, v) in body_params {
104            all_params.entry(k).or_insert(v);
105        }
106    }
107
108    let aws_request = AwsRequest {
109        service: detected.service.clone(),
110        action: detected.action.clone(),
111        region,
112        account_id: config.account_id.clone(),
113        request_id: request_id.clone(),
114        headers: parts.headers,
115        query_params: all_params,
116        body: body_bytes,
117        path_segments,
118        raw_path: path,
119        method: parts.method,
120        is_query_protocol: detected.protocol == AwsProtocol::Query,
121        access_key_id,
122    };
123
124    tracing::info!(
125        service = %aws_request.service,
126        action = %aws_request.action,
127        request_id = %aws_request.request_id,
128        "handling request"
129    );
130
131    match service.handle(aws_request).await {
132        Ok(resp) => {
133            let mut builder = Response::builder()
134                .status(resp.status)
135                .header("x-amzn-requestid", &request_id)
136                .header("x-amz-request-id", &request_id);
137
138            if !resp.content_type.is_empty() {
139                builder = builder.header("content-type", &resp.content_type);
140            }
141
142            for (k, v) in &resp.headers {
143                builder = builder.header(k, v);
144            }
145
146            builder.body(Body::from(resp.body)).unwrap()
147        }
148        Err(err) => {
149            tracing::warn!(
150                service = %detected.service,
151                action = %detected.action,
152                error = %err,
153                "request failed"
154            );
155            let error_headers = err.response_headers().to_vec();
156            let mut resp = build_error_response_with_fields(
157                err.status(),
158                err.code(),
159                &err.message(),
160                &request_id,
161                detected.protocol,
162                err.extra_fields(),
163            );
164            for (k, v) in &error_headers {
165                if let (Ok(name), Ok(val)) = (
166                    k.parse::<http::header::HeaderName>(),
167                    v.parse::<http::header::HeaderValue>(),
168                ) {
169                    resp.headers_mut().insert(name, val);
170                }
171            }
172            resp
173        }
174    }
175}
176
/// Static settings shared with every invocation of the dispatch handler.
#[derive(Clone, Debug)]
pub struct DispatchConfig {
    /// Region used when a request supplies none of its own.
    pub region: String,
    /// Account id copied onto every request handed to a service.
    pub account_id: String,
}
183
184/// Extract region from User-Agent header suffix `region/<region>`.
185fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
186    let ua = headers.get("user-agent")?.to_str().ok()?;
187    for part in ua.split_whitespace() {
188        if let Some(region) = part.strip_prefix("region/") {
189            if !region.is_empty() {
190                return Some(region.to_string());
191            }
192        }
193    }
194    None
195}
196
197fn build_error_response(
198    status: StatusCode,
199    code: &str,
200    message: &str,
201    request_id: &str,
202    protocol: AwsProtocol,
203) -> Response<Body> {
204    build_error_response_with_fields(status, code, message, request_id, protocol, &[])
205}
206
207fn build_error_response_with_fields(
208    status: StatusCode,
209    code: &str,
210    message: &str,
211    request_id: &str,
212    protocol: AwsProtocol,
213    extra_fields: &[(String, String)],
214) -> Response<Body> {
215    let (status, content_type, body) = match protocol {
216        AwsProtocol::Query => {
217            fakecloud_aws::error::xml_error_response(status, code, message, request_id)
218        }
219        AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
220            status,
221            code,
222            message,
223            request_id,
224            extra_fields,
225        ),
226        AwsProtocol::Json => fakecloud_aws::error::json_error_response(status, code, message),
227    };
228
229    Response::builder()
230        .status(status)
231        .header("content-type", content_type)
232        .header("x-amzn-requestid", request_id)
233        .header("x-amz-request-id", request_id)
234        .body(Body::from(body))
235        .unwrap()
236}
237
238trait ProtocolExt {
239    fn error_status(&self) -> StatusCode;
240}
241
242impl ProtocolExt for AwsProtocol {
243    fn error_status(&self) -> StatusCode {
244        StatusCode::BAD_REQUEST
245    }
246}