1use axum::body::Body;
2use axum::extract::{Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::protocol::{self, AwsProtocol};
9use crate::registry::ServiceRegistry;
10use crate::service::AwsRequest;
11
12pub async fn dispatch(
14 Extension(registry): Extension<Arc<ServiceRegistry>>,
15 Extension(config): Extension<Arc<DispatchConfig>>,
16 Query(query_params): Query<HashMap<String, String>>,
17 request: Request<Body>,
18) -> Response<Body> {
19 let request_id = uuid::Uuid::new_v4().to_string();
20
21 let (parts, body) = request.into_parts();
22 let body_bytes = match axum::body::to_bytes(body, 10 * 1024 * 1024).await {
23 Ok(b) => b,
24 Err(_) => {
25 return build_error_response(
26 StatusCode::PAYLOAD_TOO_LARGE,
27 "RequestEntityTooLarge",
28 "Request body too large",
29 &request_id,
30 AwsProtocol::Query,
31 );
32 }
33 };
34
35 let detected = match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
37 Some(d) => d,
38 None => {
39 if parts.method == http::Method::OPTIONS {
45 protocol::DetectedRequest {
46 service: "s3".to_string(),
47 action: String::new(),
48 protocol: AwsProtocol::Rest,
49 }
50 } else if !parts.uri.path().starts_with("/_") {
51 protocol::DetectedRequest {
56 service: "apigateway".to_string(),
57 action: String::new(),
58 protocol: AwsProtocol::RestJson,
59 }
60 } else {
61 return build_error_response(
62 StatusCode::BAD_REQUEST,
63 "MissingAction",
64 "Could not determine target service or action from request",
65 &request_id,
66 AwsProtocol::Query,
67 );
68 }
69 }
70 };
71
72 let service = match registry.get(&detected.service) {
74 Some(s) => s,
75 None => {
76 return build_error_response(
77 detected.protocol.error_status(),
78 "UnknownService",
79 &format!("Service '{}' is not available", detected.service),
80 &request_id,
81 detected.protocol,
82 );
83 }
84 };
85
86 let sigv4_info = fakecloud_aws::sigv4::parse_sigv4(
88 parts
89 .headers
90 .get("authorization")
91 .and_then(|v| v.to_str().ok())
92 .unwrap_or(""),
93 );
94 let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());
95 let region = sigv4_info
96 .map(|info| info.region)
97 .or_else(|| extract_region_from_user_agent(&parts.headers))
98 .unwrap_or_else(|| config.region.clone());
99
100 let path = parts.uri.path().to_string();
102 let raw_query = parts.uri.query().unwrap_or("").to_string();
103 let path_segments: Vec<String> = path
104 .split('/')
105 .filter(|s| !s.is_empty())
106 .map(|s| s.to_string())
107 .collect();
108
109 if detected.protocol == AwsProtocol::Json
111 && !body_bytes.is_empty()
112 && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
113 {
114 return build_error_response(
115 StatusCode::BAD_REQUEST,
116 "SerializationException",
117 "Start of structure or map found where not expected",
118 &request_id,
119 AwsProtocol::Json,
120 );
121 }
122
123 let mut all_params = query_params;
125 if detected.protocol == AwsProtocol::Query {
126 let body_params = protocol::parse_query_body(&body_bytes);
127 for (k, v) in body_params {
128 all_params.entry(k).or_insert(v);
129 }
130 }
131
132 let aws_request = AwsRequest {
133 service: detected.service.clone(),
134 action: detected.action.clone(),
135 region,
136 account_id: config.account_id.clone(),
137 request_id: request_id.clone(),
138 headers: parts.headers,
139 query_params: all_params,
140 body: body_bytes,
141 path_segments,
142 raw_path: path,
143 raw_query,
144 method: parts.method,
145 is_query_protocol: detected.protocol == AwsProtocol::Query,
146 access_key_id,
147 };
148
149 tracing::info!(
150 service = %aws_request.service,
151 action = %aws_request.action,
152 request_id = %aws_request.request_id,
153 "handling request"
154 );
155
156 match service.handle(aws_request).await {
157 Ok(resp) => {
158 let mut builder = Response::builder()
159 .status(resp.status)
160 .header("x-amzn-requestid", &request_id)
161 .header("x-amz-request-id", &request_id);
162
163 if !resp.content_type.is_empty() {
164 builder = builder.header("content-type", &resp.content_type);
165 }
166
167 for (k, v) in &resp.headers {
168 builder = builder.header(k, v);
169 }
170
171 builder.body(Body::from(resp.body)).unwrap()
172 }
173 Err(err) => {
174 tracing::warn!(
175 service = %detected.service,
176 action = %detected.action,
177 error = %err,
178 "request failed"
179 );
180 let error_headers = err.response_headers().to_vec();
181 let mut resp = build_error_response_with_fields(
182 err.status(),
183 err.code(),
184 &err.message(),
185 &request_id,
186 detected.protocol,
187 err.extra_fields(),
188 );
189 for (k, v) in &error_headers {
190 if let (Ok(name), Ok(val)) = (
191 k.parse::<http::header::HeaderName>(),
192 v.parse::<http::header::HeaderValue>(),
193 ) {
194 resp.headers_mut().insert(name, val);
195 }
196 }
197 resp
198 }
199 }
200}
201
/// Process-wide defaults that `dispatch` applies to every request it routes.
#[derive(Debug, Clone)]
pub struct DispatchConfig {
    /// Fallback region, used when neither the SigV4 `Authorization` header nor a
    /// `region/...` token in the `User-Agent` yields one.
    pub region: String,
    /// Account id copied into every `AwsRequest` built by `dispatch`.
    pub account_id: String,
}
208
209fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
211 let ua = headers.get("user-agent")?.to_str().ok()?;
212 for part in ua.split_whitespace() {
213 if let Some(region) = part.strip_prefix("region/") {
214 if !region.is_empty() {
215 return Some(region.to_string());
216 }
217 }
218 }
219 None
220}
221
222fn build_error_response(
223 status: StatusCode,
224 code: &str,
225 message: &str,
226 request_id: &str,
227 protocol: AwsProtocol,
228) -> Response<Body> {
229 build_error_response_with_fields(status, code, message, request_id, protocol, &[])
230}
231
232fn build_error_response_with_fields(
233 status: StatusCode,
234 code: &str,
235 message: &str,
236 request_id: &str,
237 protocol: AwsProtocol,
238 extra_fields: &[(String, String)],
239) -> Response<Body> {
240 let (status, content_type, body) = match protocol {
241 AwsProtocol::Query => {
242 fakecloud_aws::error::xml_error_response(status, code, message, request_id)
243 }
244 AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
245 status,
246 code,
247 message,
248 request_id,
249 extra_fields,
250 ),
251 AwsProtocol::Json | AwsProtocol::RestJson => {
252 fakecloud_aws::error::json_error_response(status, code, message)
253 }
254 };
255
256 Response::builder()
257 .status(status)
258 .header("content-type", content_type)
259 .header("x-amzn-requestid", request_id)
260 .header("x-amz-request-id", request_id)
261 .body(Body::from(body))
262 .unwrap()
263}
264
265trait ProtocolExt {
266 fn error_status(&self) -> StatusCode;
267}
268
269impl ProtocolExt for AwsProtocol {
270 fn error_status(&self) -> StatusCode {
271 StatusCode::BAD_REQUEST
272 }
273}