1use axum::body::Body;
2use axum::extract::{Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::protocol::{self, AwsProtocol};
9use crate::registry::ServiceRegistry;
10use crate::service::AwsRequest;
11
12pub async fn dispatch(
14 Extension(registry): Extension<Arc<ServiceRegistry>>,
15 Extension(config): Extension<Arc<DispatchConfig>>,
16 Query(query_params): Query<HashMap<String, String>>,
17 request: Request<Body>,
18) -> Response<Body> {
19 let request_id = uuid::Uuid::new_v4().to_string();
20
21 let (parts, body) = request.into_parts();
22 let body_bytes = match axum::body::to_bytes(body, 10 * 1024 * 1024).await {
23 Ok(b) => b,
24 Err(_) => {
25 return build_error_response(
26 StatusCode::PAYLOAD_TOO_LARGE,
27 "RequestEntityTooLarge",
28 "Request body too large",
29 &request_id,
30 AwsProtocol::Query,
31 );
32 }
33 };
34
35 let detected = match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
37 Some(d) => d,
38 None => {
39 if parts.method == http::Method::OPTIONS {
42 protocol::DetectedRequest {
43 service: "s3".to_string(),
44 action: String::new(),
45 protocol: AwsProtocol::Rest,
46 }
47 } else {
48 return build_error_response(
49 StatusCode::BAD_REQUEST,
50 "MissingAction",
51 "Could not determine target service or action from request",
52 &request_id,
53 AwsProtocol::Query,
54 );
55 }
56 }
57 };
58
59 let service = match registry.get(&detected.service) {
61 Some(s) => s,
62 None => {
63 return build_error_response(
64 detected.protocol.error_status(),
65 "UnknownService",
66 &format!("Service '{}' is not available", detected.service),
67 &request_id,
68 detected.protocol,
69 );
70 }
71 };
72
73 let sigv4_info = fakecloud_aws::sigv4::parse_sigv4(
75 parts
76 .headers
77 .get("authorization")
78 .and_then(|v| v.to_str().ok())
79 .unwrap_or(""),
80 );
81 let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());
82 let region = sigv4_info
83 .map(|info| info.region)
84 .or_else(|| extract_region_from_user_agent(&parts.headers))
85 .unwrap_or_else(|| config.region.clone());
86
87 let path = parts.uri.path().to_string();
89 let raw_query = parts.uri.query().unwrap_or("").to_string();
90 let path_segments: Vec<String> = path
91 .split('/')
92 .filter(|s| !s.is_empty())
93 .map(|s| s.to_string())
94 .collect();
95
96 if detected.protocol == AwsProtocol::Json
98 && !body_bytes.is_empty()
99 && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
100 {
101 return build_error_response(
102 StatusCode::BAD_REQUEST,
103 "SerializationException",
104 "Start of structure or map found where not expected",
105 &request_id,
106 AwsProtocol::Json,
107 );
108 }
109
110 let mut all_params = query_params;
112 if detected.protocol == AwsProtocol::Query {
113 let body_params = protocol::parse_query_body(&body_bytes);
114 for (k, v) in body_params {
115 all_params.entry(k).or_insert(v);
116 }
117 }
118
119 let aws_request = AwsRequest {
120 service: detected.service.clone(),
121 action: detected.action.clone(),
122 region,
123 account_id: config.account_id.clone(),
124 request_id: request_id.clone(),
125 headers: parts.headers,
126 query_params: all_params,
127 body: body_bytes,
128 path_segments,
129 raw_path: path,
130 raw_query,
131 method: parts.method,
132 is_query_protocol: detected.protocol == AwsProtocol::Query,
133 access_key_id,
134 };
135
136 tracing::info!(
137 service = %aws_request.service,
138 action = %aws_request.action,
139 request_id = %aws_request.request_id,
140 "handling request"
141 );
142
143 match service.handle(aws_request).await {
144 Ok(resp) => {
145 let mut builder = Response::builder()
146 .status(resp.status)
147 .header("x-amzn-requestid", &request_id)
148 .header("x-amz-request-id", &request_id);
149
150 if !resp.content_type.is_empty() {
151 builder = builder.header("content-type", &resp.content_type);
152 }
153
154 for (k, v) in &resp.headers {
155 builder = builder.header(k, v);
156 }
157
158 builder.body(Body::from(resp.body)).unwrap()
159 }
160 Err(err) => {
161 tracing::warn!(
162 service = %detected.service,
163 action = %detected.action,
164 error = %err,
165 "request failed"
166 );
167 let error_headers = err.response_headers().to_vec();
168 let mut resp = build_error_response_with_fields(
169 err.status(),
170 err.code(),
171 &err.message(),
172 &request_id,
173 detected.protocol,
174 err.extra_fields(),
175 );
176 for (k, v) in &error_headers {
177 if let (Ok(name), Ok(val)) = (
178 k.parse::<http::header::HeaderName>(),
179 v.parse::<http::header::HeaderValue>(),
180 ) {
181 resp.headers_mut().insert(name, val);
182 }
183 }
184 resp
185 }
186 }
187}
188
/// Static settings shared by every dispatched request.
///
/// `region` is the last-resort region, used when neither SigV4 credentials nor
/// the `User-Agent` header supply one. `account_id` is copied into each
/// request.
#[derive(Clone, Debug)]
pub struct DispatchConfig {
    /// Fallback AWS region name.
    pub region: String,
    /// Account id attached to every request.
    pub account_id: String,
}
195
196fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
198 let ua = headers.get("user-agent")?.to_str().ok()?;
199 for part in ua.split_whitespace() {
200 if let Some(region) = part.strip_prefix("region/") {
201 if !region.is_empty() {
202 return Some(region.to_string());
203 }
204 }
205 }
206 None
207}
208
209fn build_error_response(
210 status: StatusCode,
211 code: &str,
212 message: &str,
213 request_id: &str,
214 protocol: AwsProtocol,
215) -> Response<Body> {
216 build_error_response_with_fields(status, code, message, request_id, protocol, &[])
217}
218
219fn build_error_response_with_fields(
220 status: StatusCode,
221 code: &str,
222 message: &str,
223 request_id: &str,
224 protocol: AwsProtocol,
225 extra_fields: &[(String, String)],
226) -> Response<Body> {
227 let (status, content_type, body) = match protocol {
228 AwsProtocol::Query => {
229 fakecloud_aws::error::xml_error_response(status, code, message, request_id)
230 }
231 AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
232 status,
233 code,
234 message,
235 request_id,
236 extra_fields,
237 ),
238 AwsProtocol::Json | AwsProtocol::RestJson => {
239 fakecloud_aws::error::json_error_response(status, code, message)
240 }
241 };
242
243 Response::builder()
244 .status(status)
245 .header("content-type", content_type)
246 .header("x-amzn-requestid", request_id)
247 .header("x-amz-request-id", request_id)
248 .body(Body::from(body))
249 .unwrap()
250}
251
252trait ProtocolExt {
253 fn error_status(&self) -> StatusCode;
254}
255
256impl ProtocolExt for AwsProtocol {
257 fn error_status(&self) -> StatusCode {
258 StatusCode::BAD_REQUEST
259 }
260}