1use axum::body::Body;
2use axum::extract::{Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use crate::auth::{is_root_bypass, CredentialResolver, IamMode, IamPolicyEvaluator};
9use crate::protocol::{self, AwsProtocol};
10use crate::registry::ServiceRegistry;
11use crate::service::{AwsRequest, ResponseBody};
12
13pub async fn dispatch(
15 Extension(registry): Extension<Arc<ServiceRegistry>>,
16 Extension(config): Extension<Arc<DispatchConfig>>,
17 Query(query_params): Query<HashMap<String, String>>,
18 request: Request<Body>,
19) -> Response<Body> {
20 let request_id = uuid::Uuid::new_v4().to_string();
21
22 let (parts, body) = request.into_parts();
23 const MAX_BODY_BYTES: usize = 128 * 1024 * 1024;
29 let body_bytes = match axum::body::to_bytes(body, MAX_BODY_BYTES).await {
30 Ok(b) => b,
31 Err(_) => {
32 return build_error_response(
33 StatusCode::PAYLOAD_TOO_LARGE,
34 "RequestEntityTooLarge",
35 "Request body too large",
36 &request_id,
37 AwsProtocol::Query,
38 );
39 }
40 };
41
42 let detected = match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
44 Some(d) => d,
45 None => {
46 if parts.method == http::Method::OPTIONS {
52 protocol::DetectedRequest {
53 service: "s3".to_string(),
54 action: String::new(),
55 protocol: AwsProtocol::Rest,
56 }
57 } else if !parts.uri.path().starts_with("/_") {
58 protocol::DetectedRequest {
63 service: "apigateway".to_string(),
64 action: String::new(),
65 protocol: AwsProtocol::RestJson,
66 }
67 } else {
68 return build_error_response(
69 StatusCode::BAD_REQUEST,
70 "MissingAction",
71 "Could not determine target service or action from request",
72 &request_id,
73 AwsProtocol::Query,
74 );
75 }
76 }
77 };
78
79 let service = match registry.get(&detected.service) {
81 Some(s) => s,
82 None => {
83 return build_error_response(
84 detected.protocol.error_status(),
85 "UnknownService",
86 &format!("Service '{}' is not available", detected.service),
87 &request_id,
88 detected.protocol,
89 );
90 }
91 };
92
93 let auth_header = parts
95 .headers
96 .get("authorization")
97 .and_then(|v| v.to_str().ok())
98 .unwrap_or("");
99 let header_info = fakecloud_aws::sigv4::parse_sigv4(auth_header);
100 let presigned_info = if header_info.is_none() {
101 fakecloud_aws::sigv4::parse_sigv4_presigned(&query_params).map(|p| p.as_info())
103 } else {
104 None
105 };
106 let sigv4_info = header_info.or(presigned_info);
107 let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());
108 let region = sigv4_info
109 .map(|info| info.region)
110 .or_else(|| extract_region_from_user_agent(&parts.headers))
111 .unwrap_or_else(|| config.region.clone());
112
113 let caller_akid = access_key_id.as_deref().unwrap_or("");
119 let resolved = if !caller_akid.is_empty() && !is_root_bypass(caller_akid) {
120 config
121 .credential_resolver
122 .as_ref()
123 .and_then(|r| r.resolve(caller_akid))
124 } else {
125 None
126 };
127 let caller_principal = resolved.as_ref().map(|r| r.principal.clone());
128
129 if config.verify_sigv4 && !is_root_bypass(caller_akid) && config.credential_resolver.is_some() {
134 let amz_date = parts
135 .headers
136 .get("x-amz-date")
137 .and_then(|v| v.to_str().ok());
138 let parsed = fakecloud_aws::sigv4::parse_sigv4_header(auth_header, amz_date)
139 .or_else(|| fakecloud_aws::sigv4::parse_sigv4_presigned(&query_params));
140 let parsed = match parsed {
141 Some(p) => p,
142 None => {
143 return build_error_response(
144 StatusCode::FORBIDDEN,
145 "IncompleteSignature",
146 "Request is missing or has a malformed AWS Signature",
147 &request_id,
148 detected.protocol,
149 );
150 }
151 };
152 let resolved_for_verify = match resolved.as_ref() {
153 Some(r) => r,
154 None => {
155 return build_error_response(
156 StatusCode::FORBIDDEN,
157 "InvalidClientTokenId",
158 "The security token included in the request is invalid",
159 &request_id,
160 detected.protocol,
161 );
162 }
163 };
164 let headers_vec = fakecloud_aws::sigv4::headers_from_http(&parts.headers);
165 let raw_query_for_verify = parts.uri.query().unwrap_or("").to_string();
166 let verify_req = fakecloud_aws::sigv4::VerifyRequest {
167 method: parts.method.as_str(),
168 path: parts.uri.path(),
169 query: &raw_query_for_verify,
170 headers: &headers_vec,
171 body: &body_bytes,
172 };
173 match fakecloud_aws::sigv4::verify(
174 &parsed,
175 &verify_req,
176 &resolved_for_verify.secret_access_key,
177 chrono::Utc::now(),
178 ) {
179 Ok(()) => {}
180 Err(fakecloud_aws::sigv4::SigV4Error::RequestTimeTooSkewed { .. }) => {
181 return build_error_response(
182 StatusCode::FORBIDDEN,
183 "RequestTimeTooSkewed",
184 "The difference between the request time and the current time is too large",
185 &request_id,
186 detected.protocol,
187 );
188 }
189 Err(fakecloud_aws::sigv4::SigV4Error::InvalidDate(msg)) => {
190 return build_error_response(
191 StatusCode::FORBIDDEN,
192 "IncompleteSignature",
193 &format!("Invalid x-amz-date: {msg}"),
194 &request_id,
195 detected.protocol,
196 );
197 }
198 Err(fakecloud_aws::sigv4::SigV4Error::Malformed(msg)) => {
199 return build_error_response(
200 StatusCode::FORBIDDEN,
201 "IncompleteSignature",
202 &format!("Malformed SigV4 signature: {msg}"),
203 &request_id,
204 detected.protocol,
205 );
206 }
207 Err(fakecloud_aws::sigv4::SigV4Error::SignatureMismatch) => {
208 return build_error_response(
209 StatusCode::FORBIDDEN,
210 "SignatureDoesNotMatch",
211 "The request signature we calculated does not match the signature you provided",
212 &request_id,
213 detected.protocol,
214 );
215 }
216 }
217 }
218
219 let path = parts.uri.path().to_string();
221 let raw_query = parts.uri.query().unwrap_or("").to_string();
222 let path_segments: Vec<String> = path
223 .split('/')
224 .filter(|s| !s.is_empty())
225 .map(|s| s.to_string())
226 .collect();
227
228 if detected.protocol == AwsProtocol::Json
230 && !body_bytes.is_empty()
231 && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
232 {
233 return build_error_response(
234 StatusCode::BAD_REQUEST,
235 "SerializationException",
236 "Start of structure or map found where not expected",
237 &request_id,
238 AwsProtocol::Json,
239 );
240 }
241
242 let mut all_params = query_params;
244 if detected.protocol == AwsProtocol::Query {
245 let body_params = protocol::parse_query_body(&body_bytes);
246 for (k, v) in body_params {
247 all_params.entry(k).or_insert(v);
248 }
249 }
250
251 let aws_request = AwsRequest {
252 service: detected.service.clone(),
253 action: detected.action.clone(),
254 region,
255 account_id: config.account_id.clone(),
256 request_id: request_id.clone(),
257 headers: parts.headers,
258 query_params: all_params,
259 body: body_bytes,
260 path_segments,
261 raw_path: path,
262 raw_query,
263 method: parts.method,
264 is_query_protocol: detected.protocol == AwsProtocol::Query,
265 access_key_id,
266 principal: caller_principal,
267 };
268
269 tracing::info!(
270 service = %aws_request.service,
271 action = %aws_request.action,
272 request_id = %aws_request.request_id,
273 "handling request"
274 );
275
276 if config.iam_mode.is_enabled()
283 && service.iam_enforceable()
284 && !is_root_bypass(aws_request.access_key_id.as_deref().unwrap_or(""))
285 {
286 if let Some(evaluator) = config.policy_evaluator.as_ref() {
287 if let Some(principal) = aws_request.principal.as_ref() {
288 if !principal.is_root() {
289 if let Some(iam_action) = service.iam_action_for(&aws_request) {
290 let decision = evaluator.evaluate(principal, &iam_action);
291 if !decision.is_allow() {
292 tracing::warn!(
293 target: "fakecloud::iam::audit",
294 service = %detected.service,
295 action = %iam_action.action_string(),
296 resource = %iam_action.resource,
297 principal = %principal.arn,
298 decision = ?decision,
299 mode = %config.iam_mode,
300 request_id = %request_id,
301 "IAM policy evaluation denied request"
302 );
303 if config.iam_mode.is_strict() {
304 return build_error_response(
305 StatusCode::FORBIDDEN,
306 "AccessDeniedException",
307 &format!(
308 "User: {} is not authorized to perform: {} on resource: {}",
309 principal.arn,
310 iam_action.action_string(),
311 iam_action.resource
312 ),
313 &request_id,
314 detected.protocol,
315 );
316 }
317 }
320 } else {
321 tracing::warn!(
326 target: "fakecloud::iam::audit",
327 service = %detected.service,
328 action = %aws_request.action,
329 "service is iam_enforceable but has no IamAction mapping for this action; skipping evaluation"
330 );
331 }
332 }
333 }
334 }
335 }
336
337 match service.handle(aws_request).await {
338 Ok(resp) => {
339 let mut builder = Response::builder()
340 .status(resp.status)
341 .header("x-amzn-requestid", &request_id)
342 .header("x-amz-request-id", &request_id);
343
344 if !resp.content_type.is_empty() {
345 builder = builder.header("content-type", &resp.content_type);
346 }
347
348 let has_content_length = resp
349 .headers
350 .iter()
351 .any(|(k, _)| k.as_str().eq_ignore_ascii_case("content-length"));
352
353 for (k, v) in &resp.headers {
354 builder = builder.header(k, v);
355 }
356
357 match resp.body {
358 ResponseBody::Bytes(b) => builder.body(Body::from(b)).unwrap(),
359 ResponseBody::File { file, size } => {
360 let stream = tokio_util::io::ReaderStream::new(file);
361 let body = Body::from_stream(stream);
362 if !has_content_length {
363 builder = builder.header("content-length", size.to_string());
364 }
365 builder.body(body).unwrap()
366 }
367 }
368 }
369 Err(err) => {
370 tracing::warn!(
371 service = %detected.service,
372 action = %detected.action,
373 error = %err,
374 "request failed"
375 );
376 let error_headers = err.response_headers().to_vec();
377 let mut resp = build_error_response_with_fields(
378 err.status(),
379 err.code(),
380 &err.message(),
381 &request_id,
382 detected.protocol,
383 err.extra_fields(),
384 );
385 for (k, v) in &error_headers {
386 if let (Ok(name), Ok(val)) = (
387 k.parse::<http::header::HeaderName>(),
388 v.parse::<http::header::HeaderValue>(),
389 ) {
390 resp.headers_mut().insert(name, val);
391 }
392 }
393 resp
394 }
395 }
396}
397
/// Settings consulted by [`dispatch`] for every request.
#[derive(Clone)]
pub struct DispatchConfig {
    /// Fallback region, used when neither the SigV4 credentials nor a
    /// `region/...` User-Agent token supplies one.
    pub region: String,
    /// Account id copied into every [`AwsRequest`].
    pub account_id: String,
    /// When `true`, [`dispatch`] verifies SigV4 signatures for callers that are
    /// not root-bypass credentials. Verification only runs when a
    /// `credential_resolver` is also configured.
    pub verify_sigv4: bool,
    /// IAM enforcement mode. When enabled, policy denials are logged; in strict
    /// mode they are also rejected with `AccessDeniedException`.
    pub iam_mode: IamMode,
    /// Resolves access key ids to a principal and secret key. Used both for
    /// principal resolution and for SigV4 verification; `None` disables both.
    pub credential_resolver: Option<Arc<dyn CredentialResolver>>,
    /// Evaluates IAM policies for a principal/action pair. `None` skips IAM
    /// evaluation even when `iam_mode` is enabled.
    pub policy_evaluator: Option<Arc<dyn IamPolicyEvaluator>>,
}
421
422impl std::fmt::Debug for DispatchConfig {
423 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424 f.debug_struct("DispatchConfig")
425 .field("region", &self.region)
426 .field("account_id", &self.account_id)
427 .field("verify_sigv4", &self.verify_sigv4)
428 .field("iam_mode", &self.iam_mode)
429 .field(
430 "credential_resolver",
431 &self
432 .credential_resolver
433 .as_ref()
434 .map(|_| "<CredentialResolver>"),
435 )
436 .field(
437 "policy_evaluator",
438 &self
439 .policy_evaluator
440 .as_ref()
441 .map(|_| "<IamPolicyEvaluator>"),
442 )
443 .finish()
444 }
445}
446
447impl DispatchConfig {
448 pub fn new(region: impl Into<String>, account_id: impl Into<String>) -> Self {
451 Self {
452 region: region.into(),
453 account_id: account_id.into(),
454 verify_sigv4: false,
455 iam_mode: IamMode::Off,
456 credential_resolver: None,
457 policy_evaluator: None,
458 }
459 }
460}
461
462fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
464 let ua = headers.get("user-agent")?.to_str().ok()?;
465 for part in ua.split_whitespace() {
466 if let Some(region) = part.strip_prefix("region/") {
467 if !region.is_empty() {
468 return Some(region.to_string());
469 }
470 }
471 }
472 None
473}
474
475fn build_error_response(
476 status: StatusCode,
477 code: &str,
478 message: &str,
479 request_id: &str,
480 protocol: AwsProtocol,
481) -> Response<Body> {
482 build_error_response_with_fields(status, code, message, request_id, protocol, &[])
483}
484
485fn build_error_response_with_fields(
486 status: StatusCode,
487 code: &str,
488 message: &str,
489 request_id: &str,
490 protocol: AwsProtocol,
491 extra_fields: &[(String, String)],
492) -> Response<Body> {
493 let (status, content_type, body) = match protocol {
494 AwsProtocol::Query => {
495 fakecloud_aws::error::xml_error_response(status, code, message, request_id)
496 }
497 AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
498 status,
499 code,
500 message,
501 request_id,
502 extra_fields,
503 ),
504 AwsProtocol::Json | AwsProtocol::RestJson => {
505 fakecloud_aws::error::json_error_response(status, code, message)
506 }
507 };
508
509 Response::builder()
510 .status(status)
511 .header("content-type", content_type)
512 .header("x-amzn-requestid", request_id)
513 .header("x-amz-request-id", request_id)
514 .body(Body::from(body))
515 .unwrap()
516}
517
/// Protocol-dependent error attributes used when building dispatcher errors.
trait ProtocolExt {
    /// HTTP status for dispatcher-level errors. In this file it is used for
    /// the `UnknownService` error.
    fn error_status(&self) -> StatusCode;
}

impl ProtocolExt for AwsProtocol {
    /// Currently `400 Bad Request` for every protocol.
    fn error_status(&self) -> StatusCode {
        StatusCode::BAD_REQUEST
    }
}
527
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dispatch_config_new_defaults_to_off() {
        let cfg = DispatchConfig::new("us-east-1", "123456789012");
        assert_eq!(cfg.region, "us-east-1");
        assert_eq!(cfg.account_id, "123456789012");
        assert!(!cfg.verify_sigv4);
        assert_eq!(cfg.iam_mode, IamMode::Off);
        assert!(cfg.credential_resolver.is_none());
        assert!(cfg.policy_evaluator.is_none());
    }

    #[test]
    fn dispatch_config_carries_opt_in_flags() {
        let cfg = DispatchConfig {
            region: "eu-west-1".to_string(),
            account_id: "000000000000".to_string(),
            verify_sigv4: true,
            iam_mode: IamMode::Strict,
            credential_resolver: None,
            policy_evaluator: None,
        };
        assert!(cfg.verify_sigv4);
        assert!(cfg.iam_mode.is_strict());
    }

    /// The Debug impl must render trait-object fields as placeholders/`None`.
    #[test]
    fn dispatch_config_debug_redacts_trait_objects() {
        let cfg = DispatchConfig::new("us-east-1", "123456789012");
        let rendered = format!("{cfg:?}");
        assert!(rendered.starts_with("DispatchConfig {"));
        assert!(rendered.contains("region: \"us-east-1\""));
        assert!(rendered.contains("credential_resolver: None"));
        assert!(rendered.contains("policy_evaluator: None"));
    }

    #[test]
    fn extract_region_from_user_agent_finds_region_token() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            "user-agent",
            http::HeaderValue::from_static("aws-sdk-rust/1.0 os/linux region/eu-central-1"),
        );
        assert_eq!(
            extract_region_from_user_agent(&headers).as_deref(),
            Some("eu-central-1")
        );
    }

    #[test]
    fn extract_region_from_user_agent_ignores_missing_or_empty() {
        assert_eq!(extract_region_from_user_agent(&http::HeaderMap::new()), None);

        let mut headers = http::HeaderMap::new();
        headers.insert("user-agent", http::HeaderValue::from_static("curl/8.0 region/"));
        assert_eq!(extract_region_from_user_agent(&headers), None);
    }
}