1use axum::body::Body;
2use axum::extract::{ConnectInfo, Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use bytes::Bytes;
6use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9
10use crate::auth::{
11 is_root_bypass, ConditionContext, CredentialResolver, IamMode, IamPolicyEvaluator, Principal,
12 PrincipalType, ResourcePolicyProvider,
13};
14use crate::protocol::{self, AwsProtocol};
15use crate::registry::ServiceRegistry;
16use crate::service::{AwsRequest, ResponseBody};
17
18pub async fn dispatch(
20 ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
21 Extension(registry): Extension<Arc<ServiceRegistry>>,
22 Extension(config): Extension<Arc<DispatchConfig>>,
23 Query(query_params): Query<HashMap<String, String>>,
24 request: Request<Body>,
25) -> Response<Body> {
26 let remote_addr = Some(remote_addr);
27 let request_id = uuid::Uuid::new_v4().to_string();
28
29 let (parts, body) = request.into_parts();
30
31 let stream_route = streaming_route(
37 &parts.method,
38 parts.uri.path(),
39 &parts.headers,
40 &query_params,
41 );
42 let header_only = protocol::detect_service_headers_only(&parts.headers, &query_params);
43 let stream_dispatch = match (&stream_route, &header_only) {
44 (Some(sr), Some(detected)) if sr.0 == detected.service => Some(detected.clone()),
47 (Some((service, _)), None) if *service == "ecr" => Some(protocol::DetectedRequest {
53 service: "ecr".to_string(),
54 action: String::new(),
55 protocol: AwsProtocol::Rest,
56 }),
57 _ => None,
58 };
59
60 let (body_bytes, body_stream) = if stream_dispatch.is_some() {
61 (Bytes::new(), Some(body))
62 } else {
63 let max_body_bytes = max_request_body_bytes();
68 match axum::body::to_bytes(body, max_body_bytes).await {
69 Ok(b) => (b, None),
70 Err(_) => {
71 return build_error_response(
72 StatusCode::PAYLOAD_TOO_LARGE,
73 "RequestEntityTooLarge",
74 "Request body too large",
75 &request_id,
76 AwsProtocol::Query,
77 );
78 }
79 }
80 };
81
82 let detected = if let Some(d) = stream_dispatch {
84 d
85 } else {
86 match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
87 Some(d) => d,
88 None => {
89 if parts.method == http::Method::OPTIONS {
95 protocol::DetectedRequest {
96 service: "s3".to_string(),
97 action: String::new(),
98 protocol: AwsProtocol::Rest,
99 }
100 } else if parts.uri.path() == "/v2" || parts.uri.path().starts_with("/v2/") {
101 protocol::DetectedRequest {
105 service: "ecr".to_string(),
106 action: String::new(),
107 protocol: AwsProtocol::Rest,
108 }
109 } else if !parts.uri.path().starts_with("/_") {
110 protocol::DetectedRequest {
115 service: "apigateway".to_string(),
116 action: String::new(),
117 protocol: AwsProtocol::RestJson,
118 }
119 } else {
120 return build_error_response(
121 StatusCode::BAD_REQUEST,
122 "MissingAction",
123 "Could not determine target service or action from request",
124 &request_id,
125 AwsProtocol::Query,
126 );
127 }
128 }
129 }
130 };
131
132 let detected = if detected.service == "bedrock" {
136 let first_seg = parts.uri.path().split('/').nth(1);
137 if matches!(
138 first_seg,
139 Some(
140 "agents"
141 | "knowledgebases"
142 | "flows"
143 | "prompts"
144 | "tags"
145 | "retrieveAndGenerate"
146 | "retrieveAndGenerateStream"
147 | "optimize-prompt"
148 | "sessions"
149 | "invocations"
150 | "generate-query"
151 | "rerank"
152 )
153 ) {
154 let segs: Vec<&str> = parts.uri.path().split('/').collect();
156 let is_runtime = matches!(
157 segs.as_slice(),
158 ["", "agents", _, "agentAliases", _, ..] | ["", "flows", _, "aliases", _] | ["", "knowledgebases", _, "retrieve"] | ["", "retrieveAndGenerate"]
162 | ["", "retrieveAndGenerateStream"]
163 | ["", "optimize-prompt"]
164 | ["", "sessions", ..]
165 | ["", "invocations", ..]
166 | ["", "generate-query"]
167 | ["", "rerank"]
168 );
169 if is_runtime {
170 protocol::DetectedRequest {
171 service: "bedrock-agent-runtime".to_string(),
172 ..detected
173 }
174 } else {
175 protocol::DetectedRequest {
176 service: "bedrock-agent".to_string(),
177 ..detected
178 }
179 }
180 } else {
181 detected
182 }
183 } else {
184 detected
185 };
186
187 let service = match registry.get(&detected.service) {
189 Some(s) => s,
190 None => {
191 return build_error_response(
192 detected.protocol.error_status(),
193 "UnknownService",
194 &format!("Service '{}' is not available", detected.service),
195 &request_id,
196 detected.protocol,
197 );
198 }
199 };
200
201 let auth_header = parts
203 .headers
204 .get("authorization")
205 .and_then(|v| v.to_str().ok())
206 .unwrap_or("");
207 let header_info = fakecloud_aws::sigv4::parse_sigv4(auth_header);
208 let presigned_info = if header_info.is_none() {
209 fakecloud_aws::sigv4::parse_sigv4_presigned(&query_params).map(|p| p.as_info())
211 } else {
212 None
213 };
214 let sigv4_info = header_info.or(presigned_info);
215 let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());
216
217 let host_info = protocol::parse_routing_host_from_headers(&parts.headers);
223
224 let region = sigv4_info
225 .map(|info| info.region)
226 .or_else(|| host_info.as_ref().map(|h| h.region.clone()))
227 .or_else(|| extract_region_from_user_agent(&parts.headers))
228 .unwrap_or_else(|| config.region.clone());
229
230 let caller_akid = access_key_id.as_deref().unwrap_or("");
236 let resolved = if !caller_akid.is_empty() && !is_root_bypass(caller_akid) {
237 config
238 .credential_resolver
239 .as_ref()
240 .and_then(|r| r.resolve(caller_akid))
241 } else {
242 None
243 };
244 let caller_principal = resolved.as_ref().map(|r| r.principal.clone());
245 let caller_session_policies = resolved
246 .as_ref()
247 .map(|r| r.session_policies.clone())
248 .unwrap_or_default();
249
250 if config.verify_sigv4 && !is_root_bypass(caller_akid) && config.credential_resolver.is_some() {
255 let amz_date = parts
256 .headers
257 .get("x-amz-date")
258 .and_then(|v| v.to_str().ok());
259 let parsed = fakecloud_aws::sigv4::parse_sigv4_header(auth_header, amz_date)
260 .or_else(|| fakecloud_aws::sigv4::parse_sigv4_presigned(&query_params));
261 let parsed = match parsed {
262 Some(p) => p,
263 None => {
264 return build_error_response(
265 StatusCode::FORBIDDEN,
266 "IncompleteSignature",
267 "Request is missing or has a malformed AWS Signature",
268 &request_id,
269 detected.protocol,
270 );
271 }
272 };
273 let resolved_for_verify = match resolved.as_ref() {
274 Some(r) => r,
275 None => {
276 return build_error_response(
277 StatusCode::FORBIDDEN,
278 "InvalidClientTokenId",
279 "The security token included in the request is invalid",
280 &request_id,
281 detected.protocol,
282 );
283 }
284 };
285 let headers_vec = fakecloud_aws::sigv4::headers_from_http(&parts.headers);
286 let raw_query_for_verify = parts.uri.query().unwrap_or("").to_string();
287 let verify_req = fakecloud_aws::sigv4::VerifyRequest {
288 method: parts.method.as_str(),
289 path: parts.uri.path(),
290 query: &raw_query_for_verify,
291 headers: &headers_vec,
292 body: &body_bytes,
293 };
294 match fakecloud_aws::sigv4::verify(
295 &parsed,
296 &verify_req,
297 &resolved_for_verify.secret_access_key,
298 chrono::Utc::now(),
299 ) {
300 Ok(()) => {}
301 Err(fakecloud_aws::sigv4::SigV4Error::RequestTimeTooSkewed { .. }) => {
302 return build_error_response(
303 StatusCode::FORBIDDEN,
304 "RequestTimeTooSkewed",
305 "The difference between the request time and the current time is too large",
306 &request_id,
307 detected.protocol,
308 );
309 }
310 Err(fakecloud_aws::sigv4::SigV4Error::InvalidDate(msg)) => {
311 return build_error_response(
312 StatusCode::FORBIDDEN,
313 "IncompleteSignature",
314 &format!("Invalid x-amz-date: {msg}"),
315 &request_id,
316 detected.protocol,
317 );
318 }
319 Err(fakecloud_aws::sigv4::SigV4Error::Malformed(msg)) => {
320 return build_error_response(
321 StatusCode::FORBIDDEN,
322 "IncompleteSignature",
323 &format!("Malformed SigV4 signature: {msg}"),
324 &request_id,
325 detected.protocol,
326 );
327 }
328 Err(fakecloud_aws::sigv4::SigV4Error::SignatureMismatch) => {
329 return build_error_response(
330 StatusCode::FORBIDDEN,
331 "SignatureDoesNotMatch",
332 "The request signature we calculated does not match the signature you provided",
333 &request_id,
334 detected.protocol,
335 );
336 }
337 Err(fakecloud_aws::sigv4::SigV4Error::PresignedUrlExpired { .. }) => {
338 return build_error_response(
339 StatusCode::FORBIDDEN,
340 "AccessDenied",
341 "Request has expired",
342 &request_id,
343 detected.protocol,
344 );
345 }
346 Err(fakecloud_aws::sigv4::SigV4Error::InvalidPresignExpires(_)) => {
347 return build_error_response(
348 StatusCode::BAD_REQUEST,
349 "AuthorizationQueryParametersError",
350 "X-Amz-Expires must be a number between 1 and 604800 seconds",
351 &request_id,
352 detected.protocol,
353 );
354 }
355 }
356 }
357
358 let wire_path = parts.uri.path();
363 let path = if detected.service == "s3" {
364 if let Some(bucket) = host_info.as_ref().and_then(|h| h.bucket.as_deref()) {
365 let prefix_with_slash = format!("/{bucket}/");
366 let is_bucket_root = wire_path.trim_end_matches('/') == format!("/{bucket}");
367 if wire_path.starts_with(&prefix_with_slash) || is_bucket_root {
368 wire_path.to_string()
369 } else if wire_path == "/" || wire_path.is_empty() {
370 format!("/{bucket}")
371 } else {
372 format!("/{bucket}{wire_path}")
373 }
374 } else {
375 wire_path.to_string()
376 }
377 } else {
378 wire_path.to_string()
379 };
380 let raw_query = parts.uri.query().unwrap_or("").to_string();
381 let path_segments: Vec<String> = path
382 .split('/')
383 .filter(|s| !s.is_empty())
384 .map(|s| s.to_string())
385 .collect();
386
387 if detected.protocol == AwsProtocol::Json
389 && !body_bytes.is_empty()
390 && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
391 {
392 return build_error_response(
393 StatusCode::BAD_REQUEST,
394 "SerializationException",
395 "Start of structure or map found where not expected",
396 &request_id,
397 AwsProtocol::Json,
398 );
399 }
400
401 let mut all_params = query_params;
403 if detected.protocol == AwsProtocol::Query {
404 let body_params = protocol::parse_query_body(&body_bytes);
405 for (k, v) in body_params {
406 all_params.entry(k).or_insert(v);
407 }
408 }
409
410 let aws_request = AwsRequest {
411 service: detected.service.clone(),
412 action: detected.action.clone(),
413 region,
414 account_id: caller_principal
415 .as_ref()
416 .map(|p| p.account_id.clone())
417 .unwrap_or_else(|| config.account_id.clone()),
418 request_id: request_id.clone(),
419 headers: parts.headers,
420 query_params: all_params,
421 body: body_bytes,
422 body_stream: parking_lot::Mutex::new(body_stream),
423 path_segments,
424 raw_path: path,
425 raw_query,
426 method: parts.method,
427 is_query_protocol: detected.protocol == AwsProtocol::Query,
428 access_key_id,
429 principal: caller_principal,
430 };
431
432 tracing::info!(
433 service = %aws_request.service,
434 action = %aws_request.action,
435 request_id = %aws_request.request_id,
436 "handling request"
437 );
438
439 if config.iam_mode.is_enabled()
446 && service.iam_enforceable()
447 && !is_root_bypass(aws_request.access_key_id.as_deref().unwrap_or(""))
448 {
449 if let Some(evaluator) = config.policy_evaluator.as_ref() {
450 if let Some(principal) = aws_request.principal.as_ref() {
451 if !principal.is_root() {
452 if let Some(iam_action) = service.iam_action_for(&aws_request) {
453 let mut condition_context = build_condition_context(
454 principal,
455 remote_addr,
456 &aws_request.region,
457 is_secure_transport(&aws_request.headers),
458 );
459 if let Some(rc) = resolved.as_ref() {
467 condition_context.aws_mfa_present = Some(rc.mfa_present);
468 condition_context.aws_token_issue_time = rc.token_issued_at;
469 condition_context.aws_federated_provider =
470 rc.federated_provider.clone();
471 if rc.mfa_present {
479 if let Some(issued) = rc.token_issued_at {
480 let age = chrono::Utc::now()
481 .signed_duration_since(issued)
482 .num_seconds()
483 .max(0);
484 condition_context.aws_mfa_age_seconds = Some(age);
485 }
486 }
487 }
488 condition_context.service_keys =
489 service.iam_condition_keys_for(&aws_request, &iam_action);
490
491 match service.resource_tags_for(&iam_action.resource) {
494 Some(tags) => condition_context.resource_tags = Some(tags),
495 None => tracing::debug!(
496 target: "fakecloud::iam::audit",
497 service = %detected.service,
498 resource = %iam_action.resource,
499 "service does not expose resource tags for ABAC; skipping aws:ResourceTag/* evaluation"
500 ),
501 }
502 match service.request_tags_from(&aws_request, iam_action.action) {
504 Some(tags) => condition_context.request_tags = Some(tags),
505 None => tracing::debug!(
506 target: "fakecloud::iam::audit",
507 service = %detected.service,
508 action = %iam_action.action_string(),
509 "service does not expose request tags for ABAC; skipping aws:RequestTag/* / aws:TagKeys evaluation"
510 ),
511 }
512 condition_context.principal_tags = principal.tags.clone();
514
515 let resource_policy_json =
524 config.resource_policy_provider.as_ref().and_then(|p| {
525 p.resource_policy(&detected.service, &iam_action.resource)
526 });
527 let resource_account_id = config
537 .resource_policy_provider
538 .as_ref()
539 .and_then(|p| {
540 p.resource_owner_account(&detected.service, &iam_action.resource)
541 })
542 .or_else(|| parse_account_from_arn(&iam_action.resource))
543 .unwrap_or_else(|| principal.account_id.clone());
544 let scps = config
551 .scp_resolver
552 .as_ref()
553 .and_then(|r| r.scps_for(principal));
554 let decision = evaluator.evaluate_with_resource_policy(
555 principal,
556 &iam_action,
557 &condition_context,
558 resource_policy_json.as_deref(),
559 &resource_account_id,
560 &caller_session_policies,
561 scps.as_deref(),
562 );
563 if !decision.is_allow() {
564 tracing::warn!(
565 target: "fakecloud::iam::audit",
566 service = %detected.service,
567 action = %iam_action.action_string(),
568 resource = %iam_action.resource,
569 principal = %principal.arn,
570 resource_policy_present = resource_policy_json.is_some(),
571 decision = ?decision,
572 mode = %config.iam_mode,
573 request_id = %request_id,
574 "IAM policy evaluation denied request"
575 );
576 if config.iam_mode.is_strict() {
577 let context_summary = serde_json::json!({
590 "aws:PrincipalArn": principal.arn,
591 "aws:PrincipalAccount": principal.account_id,
592 "aws:RequestedRegion": condition_context
593 .aws_requested_region
594 .clone()
595 .unwrap_or_default(),
596 "aws:SecureTransport": condition_context
597 .aws_secure_transport
598 .unwrap_or(false),
599 "aws:Action": iam_action.action_string(),
600 "aws:Resource": iam_action.resource,
601 "decision": format!("{:?}", decision),
602 });
603 let action_string = iam_action.action_string();
604 let encoded = crate::auth_message::encode_deny(
605 matches!(decision, crate::auth::IamDecision::ExplicitDeny),
606 Some(&action_string),
607 Some(&principal.arn),
608 Vec::new(),
609 Some(context_summary),
610 );
611 return build_error_response(
612 StatusCode::FORBIDDEN,
613 "AccessDeniedException",
614 &format!(
615 "User: {} is not authorized to perform: {} on resource: {} Encoded authorization failure message: {}",
616 principal.arn,
617 iam_action.action_string(),
618 iam_action.resource,
619 encoded,
620 ),
621 &request_id,
622 detected.protocol,
623 );
624 }
625 }
628 } else {
629 tracing::warn!(
634 target: "fakecloud::iam::audit",
635 service = %detected.service,
636 action = %aws_request.action,
637 "service is iam_enforceable but has no IamAction mapping for this action; skipping evaluation"
638 );
639 }
640 }
641 }
642 }
643 }
644
645 match service.handle(aws_request).await {
646 Ok(resp) => {
647 let mut builder = Response::builder()
648 .status(resp.status)
649 .header("x-amzn-requestid", &request_id)
650 .header("x-amz-request-id", &request_id);
651
652 if !resp.content_type.is_empty() {
653 builder = builder.header("content-type", &resp.content_type);
654 }
655
656 let has_content_length = resp
657 .headers
658 .iter()
659 .any(|(k, _)| k.as_str().eq_ignore_ascii_case("content-length"));
660
661 for (k, v) in &resp.headers {
662 builder = builder.header(k, v);
663 }
664
665 match resp.body {
666 ResponseBody::Bytes(b) => builder.body(Body::from(b)).unwrap(),
667 ResponseBody::File { file, size } => {
668 let stream = tokio_util::io::ReaderStream::new(file);
669 let body = Body::from_stream(stream);
670 if !has_content_length {
671 builder = builder.header("content-length", size.to_string());
672 }
673 builder.body(body).unwrap()
674 }
675 }
676 }
677 Err(err) => {
678 tracing::warn!(
679 service = %detected.service,
680 action = %detected.action,
681 error = %err,
682 "request failed"
683 );
684 let error_headers = err.response_headers().to_vec();
685 let mut resp = build_error_response_with_fields(
686 err.status(),
687 err.code(),
688 &err.message(),
689 &request_id,
690 detected.protocol,
691 err.extra_fields(),
692 );
693 for (k, v) in &error_headers {
694 if let (Ok(name), Ok(val)) = (
695 k.parse::<http::header::HeaderName>(),
696 v.parse::<http::header::HeaderValue>(),
697 ) {
698 resp.headers_mut().insert(name, val);
699 }
700 }
701 resp
702 }
703 }
704}
705
706#[derive(Clone)]
708pub struct DispatchConfig {
709 pub region: String,
710 pub account_id: String,
711 pub verify_sigv4: bool,
715 pub iam_mode: IamMode,
720 pub credential_resolver: Option<Arc<dyn CredentialResolver>>,
724 pub policy_evaluator: Option<Arc<dyn IamPolicyEvaluator>>,
728 pub resource_policy_provider: Option<Arc<dyn ResourcePolicyProvider>>,
735 pub scp_resolver: Option<Arc<dyn crate::auth::ScpResolver>>,
742}
743
744impl std::fmt::Debug for DispatchConfig {
745 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
746 f.debug_struct("DispatchConfig")
747 .field("region", &self.region)
748 .field("account_id", &self.account_id)
749 .field("verify_sigv4", &self.verify_sigv4)
750 .field("iam_mode", &self.iam_mode)
751 .field(
752 "credential_resolver",
753 &self
754 .credential_resolver
755 .as_ref()
756 .map(|_| "<CredentialResolver>"),
757 )
758 .field(
759 "policy_evaluator",
760 &self
761 .policy_evaluator
762 .as_ref()
763 .map(|_| "<IamPolicyEvaluator>"),
764 )
765 .field(
766 "resource_policy_provider",
767 &self
768 .resource_policy_provider
769 .as_ref()
770 .map(|_| "<ResourcePolicyProvider>"),
771 )
772 .field(
773 "scp_resolver",
774 &self.scp_resolver.as_ref().map(|_| "<ScpResolver>"),
775 )
776 .finish()
777 }
778}
779
780impl DispatchConfig {
781 pub fn new(region: impl Into<String>, account_id: impl Into<String>) -> Self {
784 Self {
785 region: region.into(),
786 account_id: account_id.into(),
787 verify_sigv4: false,
788 iam_mode: IamMode::Off,
789 credential_resolver: None,
790 policy_evaluator: None,
791 resource_policy_provider: None,
792 scp_resolver: None,
793 }
794 }
795}
796
797fn streaming_route(
817 method: &http::Method,
818 path: &str,
819 headers: &http::HeaderMap,
820 query_params: &HashMap<String, String>,
821) -> Option<(&'static str, &'static str)> {
822 if (method == http::Method::PATCH || method == http::Method::PUT)
824 && path.starts_with("/v2/")
825 && path.contains("/blobs/uploads/")
826 {
827 return Some(("ecr", ""));
828 }
829
830 if method == http::Method::PUT {
835 let after = path.trim_start_matches('/');
836 let virtual_hosted_s3 = protocol::parse_routing_host_from_headers(headers)
842 .filter(|h| h.service == "s3" && h.bucket.is_some())
843 .is_some();
844 if after.is_empty() || (!virtual_hosted_s3 && !after.contains('/')) {
845 return None;
846 }
847 let header_s3 = headers
848 .get("authorization")
849 .and_then(|v| v.to_str().ok())
850 .and_then(fakecloud_aws::sigv4::parse_sigv4)
851 .map(|info| info.service == "s3")
852 .unwrap_or(false);
853 let presigned_v4_s3 = query_params
854 .get("X-Amz-Credential")
855 .and_then(|c| c.split('/').nth(3).map(|s| s.to_string()))
856 .map(|service| service == "s3")
857 .unwrap_or(false);
858 let presigned_v2 = query_params.contains_key("AWSAccessKeyId")
859 && query_params.contains_key("Signature")
860 && query_params.contains_key("Expires");
861 if header_s3 || presigned_v4_s3 || presigned_v2 {
862 return Some(("s3", ""));
863 }
864 }
865
866 None
867}
868
/// Body-size cap (1 GiB) used when `FAKECLOUD_MAX_REQUEST_BODY_BYTES` is
/// unset, unparsable, or zero.
const DEFAULT_MAX_REQUEST_BODY_BYTES: usize = 1 << 30;

/// Maximum number of bytes buffered for a non-streamed request body.
///
/// `FAKECLOUD_MAX_REQUEST_BODY_BYTES` is read once, on first call, and the
/// result is cached for the rest of the process. A missing variable, a value
/// that does not parse as `usize`, or `0` all fall back to
/// [`DEFAULT_MAX_REQUEST_BODY_BYTES`].
fn max_request_body_bytes() -> usize {
    static LIMIT: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *LIMIT.get_or_init(|| match std::env::var("FAKECLOUD_MAX_REQUEST_BODY_BYTES") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(bytes) if bytes > 0 => bytes,
            _ => DEFAULT_MAX_REQUEST_BODY_BYTES,
        },
        Err(_) => DEFAULT_MAX_REQUEST_BODY_BYTES,
    })
}
890
/// Extracts the account-id field from an ARN of the form
/// `arn:partition:service:region:account:resource`.
///
/// Returns `None` in three cases:
/// * the string does not start with `arn`;
/// * it has fewer than six `:`-separated fields;
/// * the account field is empty (e.g. S3 bucket ARNs).
///
/// The resource part may itself contain `:`; only the first five separators
/// are significant.
fn parse_account_from_arn(arn: &str) -> Option<String> {
    let fields: Vec<&str> = arn.splitn(6, ':').collect();
    match fields.as_slice() {
        [prefix, _partition, _service, _region, account, _resource]
            if *prefix == "arn" && !account.is_empty() =>
        {
            Some(account.to_string())
        }
        _ => None,
    }
}
914
915fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
917 let ua = headers.get("user-agent")?.to_str().ok()?;
918 for part in ua.split_whitespace() {
919 if let Some(region) = part.strip_prefix("region/") {
920 if !region.is_empty() {
921 return Some(region.to_string());
922 }
923 }
924 }
925 None
926}
927
928fn build_error_response(
929 status: StatusCode,
930 code: &str,
931 message: &str,
932 request_id: &str,
933 protocol: AwsProtocol,
934) -> Response<Body> {
935 build_error_response_with_fields(status, code, message, request_id, protocol, &[])
936}
937
938fn build_error_response_with_fields(
939 status: StatusCode,
940 code: &str,
941 message: &str,
942 request_id: &str,
943 protocol: AwsProtocol,
944 extra_fields: &[(String, String)],
945) -> Response<Body> {
946 let (status, content_type, body) = match protocol {
947 AwsProtocol::Query => {
948 fakecloud_aws::error::xml_error_response(status, code, message, request_id)
949 }
950 AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
951 status,
952 code,
953 message,
954 request_id,
955 extra_fields,
956 ),
957 AwsProtocol::Json | AwsProtocol::RestJson => {
958 fakecloud_aws::error::json_error_response(status, code, message)
959 }
960 };
961
962 let safe_code = sanitize_header_value(code);
972 let safe_message = sanitize_header_value(message);
973 let mut builder = Response::builder()
974 .status(status)
975 .header("content-type", content_type)
976 .header("x-amzn-requestid", request_id)
977 .header("x-amz-request-id", request_id);
978 if let Ok(v) = http::HeaderValue::from_str(&safe_code) {
979 builder = builder.header("x-amz-error-code", v);
980 }
981 if let Ok(v) = http::HeaderValue::from_str(&safe_message) {
982 builder = builder.header("x-amz-error-message", v);
983 }
984 builder.body(Body::from(body)).unwrap_or_else(|_| {
985 Response::new(Body::empty())
989 })
990}
991
/// Makes `s` safe to echo in an HTTP header such as `x-amz-error-message`.
///
/// The transformation has three steps:
/// * Every run of control characters (CR, LF, tab, DEL, C1 controls, …)
///   collapses to a single space. No space is added when the output already
///   ends with one.
/// * The result is truncated at a character boundary so that it is never
///   longer than 1024 bytes. Previously the cap was checked before pushing,
///   so a multi-byte character at the limit could overshoot it by up to 3
///   bytes.
/// * Leading and trailing whitespace is trimmed.
fn sanitize_header_value(s: &str) -> String {
    const MAX_LEN: usize = 1024;
    let mut out = String::with_capacity(s.len().min(MAX_LEN));
    for ch in s.chars() {
        let ch = if ch.is_control() {
            if out.ends_with(' ') {
                continue;
            }
            ' '
        } else {
            ch
        };
        if out.len() + ch.len_utf8() > MAX_LEN {
            break;
        }
        out.push(ch);
    }
    out.trim().to_string()
}
1015
1016fn build_condition_context(
1021 principal: &Principal,
1022 remote_addr: Option<SocketAddr>,
1023 region: &str,
1024 secure_transport: bool,
1025) -> ConditionContext {
1026 let now = chrono::Utc::now();
1027 ConditionContext {
1028 aws_username: aws_username_from_principal(principal),
1029 aws_userid: Some(principal.user_id.clone()),
1030 aws_principal_arn: Some(principal.arn.clone()),
1031 aws_principal_account: Some(principal.account_id.clone()),
1032 aws_principal_type: Some(principal_type_label(principal.principal_type).to_string()),
1033 aws_source_ip: remote_addr.map(|sa| sa.ip()),
1034 aws_current_time: Some(now),
1035 aws_epoch_time: Some(now.timestamp()),
1036 aws_secure_transport: Some(secure_transport),
1037 aws_requested_region: Some(region.to_string()),
1038 aws_mfa_present: None,
1044 aws_mfa_age_seconds: None,
1045 aws_called_via: Vec::new(),
1046 aws_source_vpce: None,
1047 aws_source_vpc: None,
1048 aws_vpc_source_ip: None,
1049 aws_federated_provider: None,
1050 aws_token_issue_time: None,
1051 service_keys: Default::default(),
1052 resource_tags: None,
1053 request_tags: None,
1054 principal_tags: None,
1055 }
1056}
1057
1058fn aws_username_from_principal(principal: &Principal) -> Option<String> {
1062 if principal.principal_type != PrincipalType::User {
1063 return None;
1064 }
1065 let after = principal.arn.rsplit_once(":user/").map(|(_, s)| s)?;
1066 Some(after.rsplit('/').next().unwrap_or(after).to_string())
1068}
1069
1070fn principal_type_label(t: PrincipalType) -> &'static str {
1073 match t {
1074 PrincipalType::User => "User",
1075 PrincipalType::AssumedRole => "AssumedRole",
1076 PrincipalType::FederatedUser => "FederatedUser",
1077 PrincipalType::Root => "Account",
1078 PrincipalType::Unknown => "Unknown",
1079 }
1080}
1081
1082fn is_secure_transport(headers: &http::HeaderMap) -> bool {
1088 headers
1089 .get("x-forwarded-proto")
1090 .and_then(|v| v.to_str().ok())
1091 .map(|s| s.eq_ignore_ascii_case("https"))
1092 .unwrap_or(false)
1093}
1094
1095trait ProtocolExt {
1096 fn error_status(&self) -> StatusCode;
1097}
1098
1099impl ProtocolExt for AwsProtocol {
1100 fn error_status(&self) -> StatusCode {
1101 StatusCode::BAD_REQUEST
1102 }
1103}
1104
1105#[cfg(test)]
1106mod tests {
1107 use super::*;
1108
1109 #[test]
1110 fn default_max_request_body_bytes_is_one_gib() {
1111 assert_eq!(DEFAULT_MAX_REQUEST_BODY_BYTES, 1024 * 1024 * 1024);
1115 }
1116
1117 #[test]
1118 fn dispatch_config_new_defaults_to_off() {
1119 let cfg = DispatchConfig::new("us-east-1", "123456789012");
1120 assert_eq!(cfg.region, "us-east-1");
1121 assert_eq!(cfg.account_id, "123456789012");
1122 assert!(!cfg.verify_sigv4);
1123 assert_eq!(cfg.iam_mode, IamMode::Off);
1124 }
1125
1126 #[test]
1127 fn aws_username_strips_iam_path_for_users() {
1128 let p = Principal {
1129 arn: "arn:aws:iam::123456789012:user/engineering/alice".into(),
1130 user_id: "AIDAALICE".into(),
1131 account_id: "123456789012".into(),
1132 principal_type: PrincipalType::User,
1133 source_identity: None,
1134 tags: None,
1135 };
1136 assert_eq!(aws_username_from_principal(&p), Some("alice".into()));
1137 }
1138
1139 #[test]
1140 fn aws_username_unset_for_assumed_role() {
1141 let p = Principal {
1142 arn: "arn:aws:sts::123456789012:assumed-role/ops/session".into(),
1143 user_id: "AROAOPS:session".into(),
1144 account_id: "123456789012".into(),
1145 principal_type: PrincipalType::AssumedRole,
1146 source_identity: None,
1147 tags: None,
1148 };
1149 assert_eq!(aws_username_from_principal(&p), None);
1150 }
1151
1152 #[test]
1153 fn principal_type_label_matches_aws_casing() {
1154 assert_eq!(principal_type_label(PrincipalType::User), "User");
1155 assert_eq!(
1156 principal_type_label(PrincipalType::AssumedRole),
1157 "AssumedRole"
1158 );
1159 assert_eq!(principal_type_label(PrincipalType::Root), "Account");
1160 }
1161
1162 #[test]
1163 fn build_condition_context_populates_global_keys() {
1164 let p = Principal {
1165 arn: "arn:aws:iam::123456789012:user/alice".into(),
1166 user_id: "AIDAALICE".into(),
1167 account_id: "123456789012".into(),
1168 principal_type: PrincipalType::User,
1169 source_identity: None,
1170 tags: None,
1171 };
1172 let addr: SocketAddr = "10.0.0.1:54321".parse().unwrap();
1173 let ctx = build_condition_context(&p, Some(addr), "us-east-1", false);
1174 assert_eq!(ctx.aws_username.as_deref(), Some("alice"));
1175 assert_eq!(ctx.aws_userid.as_deref(), Some("AIDAALICE"));
1176 assert_eq!(
1177 ctx.aws_principal_arn.as_deref(),
1178 Some("arn:aws:iam::123456789012:user/alice")
1179 );
1180 assert_eq!(ctx.aws_principal_account.as_deref(), Some("123456789012"));
1181 assert_eq!(ctx.aws_principal_type.as_deref(), Some("User"));
1182 assert_eq!(
1183 ctx.aws_source_ip.map(|i| i.to_string()).as_deref(),
1184 Some("10.0.0.1")
1185 );
1186 assert_eq!(ctx.aws_requested_region.as_deref(), Some("us-east-1"));
1187 assert_eq!(ctx.aws_secure_transport, Some(false));
1188 assert!(ctx.aws_current_time.is_some());
1189 assert!(ctx.aws_epoch_time.is_some());
1190 }
1191
1192 #[test]
1193 fn is_secure_transport_reads_x_forwarded_proto() {
1194 let mut headers = http::HeaderMap::new();
1195 headers.insert("x-forwarded-proto", "https".parse().unwrap());
1196 assert!(is_secure_transport(&headers));
1197 headers.insert("x-forwarded-proto", "http".parse().unwrap());
1198 assert!(!is_secure_transport(&headers));
1199 let empty = http::HeaderMap::new();
1200 assert!(!is_secure_transport(&empty));
1201 }
1202
1203 #[test]
1204 fn parse_account_from_arn_extracts_standard_shapes() {
1205 assert_eq!(
1206 parse_account_from_arn("arn:aws:sqs:us-east-1:123456789012:queue"),
1207 Some("123456789012".to_string())
1208 );
1209 assert_eq!(
1210 parse_account_from_arn("arn:aws:iam::123456789012:user/alice"),
1211 Some("123456789012".to_string())
1212 );
1213 }
1214
1215 #[test]
1216 fn parse_account_from_arn_returns_none_for_s3_empty_account() {
1217 assert_eq!(parse_account_from_arn("arn:aws:s3:::my-bucket"), None);
1219 assert_eq!(
1220 parse_account_from_arn("arn:aws:s3:::my-bucket/path/to/key"),
1221 None
1222 );
1223 }
1224
1225 #[test]
1226 fn parse_account_from_arn_returns_none_for_malformed() {
1227 assert_eq!(parse_account_from_arn(""), None);
1228 assert_eq!(parse_account_from_arn("not-an-arn"), None);
1229 assert_eq!(parse_account_from_arn("arn:aws:sqs:us-east-1"), None);
1230 assert_eq!(parse_account_from_arn("arn:aws:sqs"), None);
1231 }
1232
1233 #[test]
1234 fn extract_region_from_user_agent_finds_region_segment() {
1235 let mut headers = http::HeaderMap::new();
1236 headers.insert(
1237 "user-agent",
1238 "aws-sdk-rust/1.0 os/linux region/eu-central-1"
1239 .parse()
1240 .unwrap(),
1241 );
1242 assert_eq!(
1243 extract_region_from_user_agent(&headers),
1244 Some("eu-central-1".to_string())
1245 );
1246 }
1247
1248 #[test]
1249 fn extract_region_from_user_agent_none_without_header() {
1250 let headers = http::HeaderMap::new();
1251 assert_eq!(extract_region_from_user_agent(&headers), None);
1252 }
1253
1254 #[test]
1255 fn extract_region_from_user_agent_ignores_empty_region() {
1256 let mut headers = http::HeaderMap::new();
1257 headers.insert("user-agent", "aws-sdk-java region/".parse().unwrap());
1258 assert_eq!(extract_region_from_user_agent(&headers), None);
1259 }
1260
1261 #[test]
1262 fn extract_region_from_user_agent_none_when_no_region_marker() {
1263 let mut headers = http::HeaderMap::new();
1264 headers.insert("user-agent", "curl/7.79.1".parse().unwrap());
1265 assert_eq!(extract_region_from_user_agent(&headers), None);
1266 }
1267
1268 #[test]
1269 fn aws_username_none_for_root() {
1270 let p = Principal {
1271 arn: "arn:aws:iam::123456789012:root".into(),
1272 user_id: "123456789012".into(),
1273 account_id: "123456789012".into(),
1274 principal_type: PrincipalType::Root,
1275 source_identity: None,
1276 tags: None,
1277 };
1278 assert_eq!(aws_username_from_principal(&p), None);
1279 }
1280
1281 #[test]
1282 fn aws_username_bare_no_path() {
1283 let p = Principal {
1284 arn: "arn:aws:iam::123456789012:user/bob".into(),
1285 user_id: "AIDABOB".into(),
1286 account_id: "123456789012".into(),
1287 principal_type: PrincipalType::User,
1288 source_identity: None,
1289 tags: None,
1290 };
1291 assert_eq!(aws_username_from_principal(&p), Some("bob".into()));
1292 }
1293
1294 #[test]
1295 fn principal_type_label_covers_federated_and_unknown() {
1296 assert_eq!(
1297 principal_type_label(PrincipalType::FederatedUser),
1298 "FederatedUser"
1299 );
1300 assert_eq!(principal_type_label(PrincipalType::Unknown), "Unknown");
1301 }
1302
1303 #[test]
1304 fn build_condition_context_marks_secure_when_flag_set() {
1305 let p = Principal {
1306 arn: "arn:aws:iam::123456789012:user/alice".into(),
1307 user_id: "AIDAALICE".into(),
1308 account_id: "123456789012".into(),
1309 principal_type: PrincipalType::User,
1310 source_identity: None,
1311 tags: None,
1312 };
1313 let ctx = build_condition_context(&p, None, "us-west-2", true);
1314 assert_eq!(ctx.aws_secure_transport, Some(true));
1315 assert!(ctx.aws_source_ip.is_none());
1316 assert_eq!(ctx.aws_requested_region.as_deref(), Some("us-west-2"));
1317 }
1318
1319 #[test]
1320 fn is_secure_transport_case_insensitive() {
1321 let mut headers = http::HeaderMap::new();
1322 headers.insert("x-forwarded-proto", "HTTPS".parse().unwrap());
1323 assert!(is_secure_transport(&headers));
1324 }
1325
1326 #[test]
1327 fn is_secure_transport_non_ascii_bytes_false() {
1328 let mut headers = http::HeaderMap::new();
1329 headers.insert(
1330 "x-forwarded-proto",
1331 http::HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap(),
1332 );
1333 assert!(!is_secure_transport(&headers));
1334 }
1335
1336 #[test]
1337 fn protocol_ext_error_status_is_bad_request() {
1338 assert_eq!(AwsProtocol::Query.error_status(), StatusCode::BAD_REQUEST);
1339 assert_eq!(AwsProtocol::Json.error_status(), StatusCode::BAD_REQUEST);
1340 assert_eq!(AwsProtocol::Rest.error_status(), StatusCode::BAD_REQUEST);
1341 assert_eq!(
1342 AwsProtocol::RestJson.error_status(),
1343 StatusCode::BAD_REQUEST
1344 );
1345 }
1346
1347 #[test]
1348 fn build_error_response_json_has_json_content_type() {
1349 let resp = build_error_response(
1350 StatusCode::BAD_REQUEST,
1351 "TestCode",
1352 "test msg",
1353 "req-1",
1354 AwsProtocol::Json,
1355 );
1356 assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
1357 let ct = resp
1358 .headers()
1359 .get("content-type")
1360 .unwrap()
1361 .to_str()
1362 .unwrap();
1363 assert!(ct.contains("json"));
1364 let rid = resp
1365 .headers()
1366 .get("x-amzn-requestid")
1367 .unwrap()
1368 .to_str()
1369 .unwrap();
1370 assert_eq!(rid, "req-1");
1371 }
1372
1373 #[test]
1374 fn build_error_response_rest_returns_xml_content_type() {
1375 let resp = build_error_response(
1376 StatusCode::NOT_FOUND,
1377 "NoSuchBucket",
1378 "bucket missing",
1379 "req-2",
1380 AwsProtocol::Rest,
1381 );
1382 assert_eq!(resp.status(), StatusCode::NOT_FOUND);
1383 let ct = resp
1384 .headers()
1385 .get("content-type")
1386 .unwrap()
1387 .to_str()
1388 .unwrap();
1389 assert!(ct.contains("xml"));
1390 }
1391
1392 #[test]
1393 fn build_error_response_query_returns_xml() {
1394 let resp = build_error_response(
1395 StatusCode::BAD_REQUEST,
1396 "InvalidParameter",
1397 "bad param",
1398 "req-3",
1399 AwsProtocol::Query,
1400 );
1401 let ct = resp
1402 .headers()
1403 .get("content-type")
1404 .unwrap()
1405 .to_str()
1406 .unwrap();
1407 assert!(ct.contains("xml"));
1408 }
1409
1410 #[test]
1415 fn build_error_response_with_multiline_message_does_not_panic() {
1416 let resp = build_error_response(
1417 StatusCode::INTERNAL_SERVER_ERROR,
1418 "ServiceException",
1419 "Lambda execution failed: container failed to start: docker start failed: \
1420 Error: unable to start container \"abc\": \
1421 failed to create new hosts file:\nhost-gateway is empty\n",
1422 "req-multi",
1423 AwsProtocol::Json,
1424 );
1425 assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
1426 let msg = resp
1427 .headers()
1428 .get("x-amz-error-message")
1429 .expect("x-amz-error-message must be set even when input contains newlines")
1430 .to_str()
1431 .unwrap();
1432 assert!(!msg.contains('\n'));
1433 assert!(!msg.contains('\r'));
1434 assert!(msg.contains("Lambda execution failed"));
1435 assert!(msg.contains("host-gateway is empty"));
1436 }
1437
1438 #[test]
1439 fn build_error_response_with_control_chars_strips_them() {
1440 let resp = build_error_response(
1441 StatusCode::BAD_REQUEST,
1442 "Code\twith\ttabs",
1443 "msg\x00with\x01nulls",
1444 "req-ctrl",
1445 AwsProtocol::Json,
1446 );
1447 let code = resp
1448 .headers()
1449 .get("x-amz-error-code")
1450 .unwrap()
1451 .to_str()
1452 .unwrap();
1453 let msg = resp
1454 .headers()
1455 .get("x-amz-error-message")
1456 .unwrap()
1457 .to_str()
1458 .unwrap();
1459 assert!(!code.contains('\t'));
1460 assert!(!msg.contains('\x00'));
1461 assert!(!msg.contains('\x01'));
1462 }
1463
1464 #[test]
1465 fn sanitize_header_value_truncates_long_input() {
1466 let huge = "x".repeat(5_000);
1467 let out = sanitize_header_value(&huge);
1468 assert!(out.len() <= 1024);
1469 }
1470
1471 #[test]
1472 fn sanitize_header_value_collapses_consecutive_control_runs() {
1473 let out = sanitize_header_value("a\n\n\n\rb");
1474 assert_eq!(out, "a b");
1475 }
1476
1477 #[test]
1478 fn dispatch_config_carries_opt_in_flags() {
1479 let cfg = DispatchConfig {
1480 region: "eu-west-1".to_string(),
1481 account_id: "000000000000".to_string(),
1482 verify_sigv4: true,
1483 iam_mode: IamMode::Strict,
1484 credential_resolver: None,
1485 policy_evaluator: None,
1486 resource_policy_provider: None,
1487 scp_resolver: None,
1488 };
1489 assert!(cfg.verify_sigv4);
1490 assert!(cfg.iam_mode.is_strict());
1491 assert!(cfg.resource_policy_provider.is_none());
1492 assert!(cfg.scp_resolver.is_none());
1493 }
1494
1495 fn s3_sigv4_headers() -> http::HeaderMap {
1496 let mut headers = http::HeaderMap::new();
1497 headers.insert(
1498 "authorization",
1499 "AWS4-HMAC-SHA256 Credential=test/20240101/us-east-1/s3/aws4_request, \
1500 SignedHeaders=host, Signature=fake"
1501 .parse()
1502 .unwrap(),
1503 );
1504 headers
1505 }
1506
1507 #[test]
1508 fn streaming_route_path_style_s3_put_object() {
1509 let headers = s3_sigv4_headers();
1510 assert_eq!(
1511 streaming_route(
1512 &http::Method::PUT,
1513 "/my-bucket/key.txt",
1514 &headers,
1515 &HashMap::new(),
1516 ),
1517 Some(("s3", "")),
1518 );
1519 }
1520
1521 #[test]
1522 fn streaming_route_path_style_create_bucket_skipped() {
1523 let headers = s3_sigv4_headers();
1526 assert_eq!(
1527 streaming_route(&http::Method::PUT, "/my-bucket", &headers, &HashMap::new(),),
1528 None,
1529 );
1530 }
1531
1532 #[test]
1533 fn streaming_route_virtual_hosted_s3_put_object() {
1534 let mut headers = s3_sigv4_headers();
1535 headers.insert(
1536 "host",
1537 "vhost-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
1538 .parse()
1539 .unwrap(),
1540 );
1541 assert_eq!(
1546 streaming_route(&http::Method::PUT, "/hello.txt", &headers, &HashMap::new(),),
1547 Some(("s3", "")),
1548 );
1549 }
1550
1551 #[test]
1552 fn streaming_route_virtual_hosted_s3_root_skipped() {
1553 let mut headers = s3_sigv4_headers();
1556 headers.insert(
1557 "host",
1558 "vhost-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
1559 .parse()
1560 .unwrap(),
1561 );
1562 assert_eq!(
1563 streaming_route(&http::Method::PUT, "/", &headers, &HashMap::new()),
1564 None,
1565 );
1566 }
1567
1568 #[test]
1569 fn streaming_route_ecr_blob_upload() {
1570 let headers = http::HeaderMap::new();
1571 assert_eq!(
1572 streaming_route(
1573 &http::Method::PATCH,
1574 "/v2/my-repo/blobs/uploads/abcd1234",
1575 &headers,
1576 &HashMap::new(),
1577 ),
1578 Some(("ecr", "")),
1579 );
1580 assert_eq!(
1581 streaming_route(
1582 &http::Method::PUT,
1583 "/v2/my-repo/blobs/uploads/abcd1234",
1584 &headers,
1585 &HashMap::new(),
1586 ),
1587 Some(("ecr", "")),
1588 );
1589 }
1590
1591 #[test]
1592 fn streaming_route_presigned_v4_s3_put() {
1593 let headers = http::HeaderMap::new();
1594 let mut query_params = HashMap::new();
1595 query_params.insert(
1596 "X-Amz-Credential".to_string(),
1597 "test/20240101/us-east-1/s3/aws4_request".to_string(),
1598 );
1599 assert_eq!(
1600 streaming_route(
1601 &http::Method::PUT,
1602 "/my-bucket/key.txt",
1603 &headers,
1604 &query_params,
1605 ),
1606 Some(("s3", "")),
1607 );
1608 }
1609
1610 #[test]
1611 fn streaming_route_non_s3_auth_header_skipped() {
1612 let mut headers = http::HeaderMap::new();
1615 headers.insert(
1616 "authorization",
1617 "AWS4-HMAC-SHA256 Credential=test/20240101/us-east-1/lambda/aws4_request, \
1618 SignedHeaders=host, Signature=fake"
1619 .parse()
1620 .unwrap(),
1621 );
1622 assert_eq!(
1623 streaming_route(
1624 &http::Method::PUT,
1625 "/my-bucket/key.txt",
1626 &headers,
1627 &HashMap::new(),
1628 ),
1629 None,
1630 );
1631 }
1632
1633 #[test]
1634 fn streaming_route_get_skipped() {
1635 let headers = s3_sigv4_headers();
1636 assert_eq!(
1637 streaming_route(
1638 &http::Method::GET,
1639 "/my-bucket/key.txt",
1640 &headers,
1641 &HashMap::new(),
1642 ),
1643 None,
1644 );
1645 }
1646}