1use axum::body::Body;
2use axum::extract::{Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::protocol::{self, AwsProtocol};
9use crate::registry::ServiceRegistry;
10use crate::service::AwsRequest;
11
12pub async fn dispatch(
14 Extension(registry): Extension<Arc<ServiceRegistry>>,
15 Extension(config): Extension<Arc<DispatchConfig>>,
16 Query(query_params): Query<HashMap<String, String>>,
17 request: Request<Body>,
18) -> Response<Body> {
19 let request_id = uuid::Uuid::new_v4().to_string();
20
21 let (parts, body) = request.into_parts();
22 let body_bytes = match axum::body::to_bytes(body, 10 * 1024 * 1024).await {
23 Ok(b) => b,
24 Err(_) => {
25 return build_error_response(
26 StatusCode::PAYLOAD_TOO_LARGE,
27 "RequestEntityTooLarge",
28 "Request body too large",
29 &request_id,
30 AwsProtocol::Query,
31 );
32 }
33 };
34
35 let detected = match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
37 Some(d) => d,
38 None => {
39 return build_error_response(
40 StatusCode::BAD_REQUEST,
41 "MissingAction",
42 "Could not determine target service or action from request",
43 &request_id,
44 AwsProtocol::Query,
45 );
46 }
47 };
48
49 let service = match registry.get(&detected.service) {
51 Some(s) => s,
52 None => {
53 return build_error_response(
54 detected.protocol.error_status(),
55 "UnknownService",
56 &format!("Service '{}' is not available", detected.service),
57 &request_id,
58 detected.protocol,
59 );
60 }
61 };
62
63 let sigv4_info = fakecloud_aws::sigv4::parse_sigv4(
65 parts
66 .headers
67 .get("authorization")
68 .and_then(|v| v.to_str().ok())
69 .unwrap_or(""),
70 );
71 let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());
72 let region = sigv4_info
73 .map(|info| info.region)
74 .or_else(|| extract_region_from_user_agent(&parts.headers))
75 .unwrap_or_else(|| config.region.clone());
76
77 let path = parts.uri.path().to_string();
79 let path_segments: Vec<String> = path
80 .split('/')
81 .filter(|s| !s.is_empty())
82 .map(|s| s.to_string())
83 .collect();
84
85 if detected.protocol == AwsProtocol::Json
87 && !body_bytes.is_empty()
88 && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
89 {
90 return build_error_response(
91 StatusCode::BAD_REQUEST,
92 "SerializationException",
93 "Start of structure or map found where not expected",
94 &request_id,
95 AwsProtocol::Json,
96 );
97 }
98
99 let mut all_params = query_params;
101 if detected.protocol == AwsProtocol::Query {
102 let body_params = protocol::parse_query_body(&body_bytes);
103 for (k, v) in body_params {
104 all_params.entry(k).or_insert(v);
105 }
106 }
107
108 let aws_request = AwsRequest {
109 service: detected.service.clone(),
110 action: detected.action.clone(),
111 region,
112 account_id: config.account_id.clone(),
113 request_id: request_id.clone(),
114 headers: parts.headers,
115 query_params: all_params,
116 body: body_bytes,
117 path_segments,
118 raw_path: path,
119 method: parts.method,
120 is_query_protocol: detected.protocol == AwsProtocol::Query,
121 access_key_id,
122 };
123
124 tracing::info!(
125 service = %aws_request.service,
126 action = %aws_request.action,
127 request_id = %aws_request.request_id,
128 "handling request"
129 );
130
131 match service.handle(aws_request).await {
132 Ok(resp) => {
133 let mut builder = Response::builder()
134 .status(resp.status)
135 .header("x-amzn-requestid", &request_id)
136 .header("x-amz-request-id", &request_id);
137
138 if !resp.content_type.is_empty() {
139 builder = builder.header("content-type", &resp.content_type);
140 }
141
142 for (k, v) in &resp.headers {
143 builder = builder.header(k, v);
144 }
145
146 builder.body(Body::from(resp.body)).unwrap()
147 }
148 Err(err) => {
149 tracing::warn!(
150 service = %detected.service,
151 action = %detected.action,
152 error = %err,
153 "request failed"
154 );
155 let error_headers = err.response_headers().to_vec();
156 let mut resp = build_error_response_with_fields(
157 err.status(),
158 err.code(),
159 &err.message(),
160 &request_id,
161 detected.protocol,
162 err.extra_fields(),
163 );
164 for (k, v) in &error_headers {
165 if let (Ok(name), Ok(val)) = (
166 k.parse::<http::header::HeaderName>(),
167 v.parse::<http::header::HeaderValue>(),
168 ) {
169 resp.headers_mut().insert(name, val);
170 }
171 }
172 resp
173 }
174 }
175}
176
/// Static settings shared by every request that `dispatch` handles.
#[derive(Clone, Debug)]
pub struct DispatchConfig {
    /// Fallback region. Used when neither the parsed SigV4 `Authorization` header nor a
    /// `region/<name>` token in `User-Agent` supplies one.
    pub region: String,
    /// Account id copied into every `AwsRequest` built by `dispatch`.
    pub account_id: String,
}
183
184fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
186 let ua = headers.get("user-agent")?.to_str().ok()?;
187 for part in ua.split_whitespace() {
188 if let Some(region) = part.strip_prefix("region/") {
189 if !region.is_empty() {
190 return Some(region.to_string());
191 }
192 }
193 }
194 None
195}
196
197fn build_error_response(
198 status: StatusCode,
199 code: &str,
200 message: &str,
201 request_id: &str,
202 protocol: AwsProtocol,
203) -> Response<Body> {
204 build_error_response_with_fields(status, code, message, request_id, protocol, &[])
205}
206
207fn build_error_response_with_fields(
208 status: StatusCode,
209 code: &str,
210 message: &str,
211 request_id: &str,
212 protocol: AwsProtocol,
213 extra_fields: &[(String, String)],
214) -> Response<Body> {
215 let (status, content_type, body) = match protocol {
216 AwsProtocol::Query => {
217 fakecloud_aws::error::xml_error_response(status, code, message, request_id)
218 }
219 AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
220 status,
221 code,
222 message,
223 request_id,
224 extra_fields,
225 ),
226 AwsProtocol::Json => fakecloud_aws::error::json_error_response(status, code, message),
227 };
228
229 Response::builder()
230 .status(status)
231 .header("content-type", content_type)
232 .header("x-amzn-requestid", request_id)
233 .header("x-amz-request-id", request_id)
234 .body(Body::from(body))
235 .unwrap()
236}
237
238trait ProtocolExt {
239 fn error_status(&self) -> StatusCode;
240}
241
242impl ProtocolExt for AwsProtocol {
243 fn error_status(&self) -> StatusCode {
244 StatusCode::BAD_REQUEST
245 }
246}