1use axum::body::Body;
2use axum::extract::{Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::protocol::{self, AwsProtocol};
9use crate::registry::ServiceRegistry;
10use crate::service::{AwsRequest, ResponseBody};
11
12pub async fn dispatch(
14 Extension(registry): Extension<Arc<ServiceRegistry>>,
15 Extension(config): Extension<Arc<DispatchConfig>>,
16 Query(query_params): Query<HashMap<String, String>>,
17 request: Request<Body>,
18) -> Response<Body> {
19 let request_id = uuid::Uuid::new_v4().to_string();
20
21 let (parts, body) = request.into_parts();
22 const MAX_BODY_BYTES: usize = 128 * 1024 * 1024;
28 let body_bytes = match axum::body::to_bytes(body, MAX_BODY_BYTES).await {
29 Ok(b) => b,
30 Err(_) => {
31 return build_error_response(
32 StatusCode::PAYLOAD_TOO_LARGE,
33 "RequestEntityTooLarge",
34 "Request body too large",
35 &request_id,
36 AwsProtocol::Query,
37 );
38 }
39 };
40
41 let detected = match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
43 Some(d) => d,
44 None => {
45 if parts.method == http::Method::OPTIONS {
51 protocol::DetectedRequest {
52 service: "s3".to_string(),
53 action: String::new(),
54 protocol: AwsProtocol::Rest,
55 }
56 } else if !parts.uri.path().starts_with("/_") {
57 protocol::DetectedRequest {
62 service: "apigateway".to_string(),
63 action: String::new(),
64 protocol: AwsProtocol::RestJson,
65 }
66 } else {
67 return build_error_response(
68 StatusCode::BAD_REQUEST,
69 "MissingAction",
70 "Could not determine target service or action from request",
71 &request_id,
72 AwsProtocol::Query,
73 );
74 }
75 }
76 };
77
78 let service = match registry.get(&detected.service) {
80 Some(s) => s,
81 None => {
82 return build_error_response(
83 detected.protocol.error_status(),
84 "UnknownService",
85 &format!("Service '{}' is not available", detected.service),
86 &request_id,
87 detected.protocol,
88 );
89 }
90 };
91
92 let sigv4_info = fakecloud_aws::sigv4::parse_sigv4(
94 parts
95 .headers
96 .get("authorization")
97 .and_then(|v| v.to_str().ok())
98 .unwrap_or(""),
99 );
100 let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());
101 let region = sigv4_info
102 .map(|info| info.region)
103 .or_else(|| extract_region_from_user_agent(&parts.headers))
104 .unwrap_or_else(|| config.region.clone());
105
106 let path = parts.uri.path().to_string();
108 let raw_query = parts.uri.query().unwrap_or("").to_string();
109 let path_segments: Vec<String> = path
110 .split('/')
111 .filter(|s| !s.is_empty())
112 .map(|s| s.to_string())
113 .collect();
114
115 if detected.protocol == AwsProtocol::Json
117 && !body_bytes.is_empty()
118 && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
119 {
120 return build_error_response(
121 StatusCode::BAD_REQUEST,
122 "SerializationException",
123 "Start of structure or map found where not expected",
124 &request_id,
125 AwsProtocol::Json,
126 );
127 }
128
129 let mut all_params = query_params;
131 if detected.protocol == AwsProtocol::Query {
132 let body_params = protocol::parse_query_body(&body_bytes);
133 for (k, v) in body_params {
134 all_params.entry(k).or_insert(v);
135 }
136 }
137
138 let aws_request = AwsRequest {
139 service: detected.service.clone(),
140 action: detected.action.clone(),
141 region,
142 account_id: config.account_id.clone(),
143 request_id: request_id.clone(),
144 headers: parts.headers,
145 query_params: all_params,
146 body: body_bytes,
147 path_segments,
148 raw_path: path,
149 raw_query,
150 method: parts.method,
151 is_query_protocol: detected.protocol == AwsProtocol::Query,
152 access_key_id,
153 };
154
155 tracing::info!(
156 service = %aws_request.service,
157 action = %aws_request.action,
158 request_id = %aws_request.request_id,
159 "handling request"
160 );
161
162 match service.handle(aws_request).await {
163 Ok(resp) => {
164 let mut builder = Response::builder()
165 .status(resp.status)
166 .header("x-amzn-requestid", &request_id)
167 .header("x-amz-request-id", &request_id);
168
169 if !resp.content_type.is_empty() {
170 builder = builder.header("content-type", &resp.content_type);
171 }
172
173 let has_content_length = resp
174 .headers
175 .iter()
176 .any(|(k, _)| k.as_str().eq_ignore_ascii_case("content-length"));
177
178 for (k, v) in &resp.headers {
179 builder = builder.header(k, v);
180 }
181
182 match resp.body {
183 ResponseBody::Bytes(b) => builder.body(Body::from(b)).unwrap(),
184 ResponseBody::File { file, size } => {
185 let stream = tokio_util::io::ReaderStream::new(file);
186 let body = Body::from_stream(stream);
187 if !has_content_length {
188 builder = builder.header("content-length", size.to_string());
189 }
190 builder.body(body).unwrap()
191 }
192 }
193 }
194 Err(err) => {
195 tracing::warn!(
196 service = %detected.service,
197 action = %detected.action,
198 error = %err,
199 "request failed"
200 );
201 let error_headers = err.response_headers().to_vec();
202 let mut resp = build_error_response_with_fields(
203 err.status(),
204 err.code(),
205 &err.message(),
206 &request_id,
207 detected.protocol,
208 err.extra_fields(),
209 );
210 for (k, v) in &error_headers {
211 if let (Ok(name), Ok(val)) = (
212 k.parse::<http::header::HeaderName>(),
213 v.parse::<http::header::HeaderValue>(),
214 ) {
215 resp.headers_mut().insert(name, val);
216 }
217 }
218 resp
219 }
220 }
221}
222
/// Process-wide defaults applied to every dispatched request.
///
/// `dispatch` receives it as an `Extension<Arc<DispatchConfig>>`.
#[derive(Clone, Debug)]
pub struct DispatchConfig {
    /// Fallback region, used when neither the SigV4 credential scope nor a
    /// `region/<name>` User-Agent token supplies one.
    pub region: String,
    /// Account ID copied into every `AwsRequest`.
    pub account_id: String,
}
229
230fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
232 let ua = headers.get("user-agent")?.to_str().ok()?;
233 for part in ua.split_whitespace() {
234 if let Some(region) = part.strip_prefix("region/") {
235 if !region.is_empty() {
236 return Some(region.to_string());
237 }
238 }
239 }
240 None
241}
242
243fn build_error_response(
244 status: StatusCode,
245 code: &str,
246 message: &str,
247 request_id: &str,
248 protocol: AwsProtocol,
249) -> Response<Body> {
250 build_error_response_with_fields(status, code, message, request_id, protocol, &[])
251}
252
253fn build_error_response_with_fields(
254 status: StatusCode,
255 code: &str,
256 message: &str,
257 request_id: &str,
258 protocol: AwsProtocol,
259 extra_fields: &[(String, String)],
260) -> Response<Body> {
261 let (status, content_type, body) = match protocol {
262 AwsProtocol::Query => {
263 fakecloud_aws::error::xml_error_response(status, code, message, request_id)
264 }
265 AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
266 status,
267 code,
268 message,
269 request_id,
270 extra_fields,
271 ),
272 AwsProtocol::Json | AwsProtocol::RestJson => {
273 fakecloud_aws::error::json_error_response(status, code, message)
274 }
275 };
276
277 Response::builder()
278 .status(status)
279 .header("content-type", content_type)
280 .header("x-amzn-requestid", request_id)
281 .header("x-amz-request-id", request_id)
282 .body(Body::from(body))
283 .unwrap()
284}
285
286trait ProtocolExt {
287 fn error_status(&self) -> StatusCode;
288}
289
290impl ProtocolExt for AwsProtocol {
291 fn error_status(&self) -> StatusCode {
292 StatusCode::BAD_REQUEST
293 }
294}