1use axum::body::Body;
2use axum::extract::{ConnectInfo, Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use bytes::Bytes;
6use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9
10use crate::auth::{
11 is_root_bypass, ConditionContext, CredentialResolver, IamMode, IamPolicyEvaluator, Principal,
12 PrincipalType, ResourcePolicyProvider,
13};
14use crate::protocol::{self, AwsProtocol};
15use crate::registry::ServiceRegistry;
16use crate::service::{AwsRequest, ResponseBody};
17
18pub async fn dispatch(
20 ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
21 Extension(registry): Extension<Arc<ServiceRegistry>>,
22 Extension(config): Extension<Arc<DispatchConfig>>,
23 Query(query_params): Query<HashMap<String, String>>,
24 request: Request<Body>,
25) -> Response<Body> {
26 let remote_addr = Some(remote_addr);
27 let request_id = uuid::Uuid::new_v4().to_string();
28
29 let (parts, body) = request.into_parts();
30
31 let stream_route = streaming_route(
37 &parts.method,
38 parts.uri.path(),
39 &parts.headers,
40 &query_params,
41 );
42 let header_only = protocol::detect_service_headers_only(&parts.headers, &query_params);
43 let stream_dispatch = match (&stream_route, &header_only) {
44 (Some(sr), Some(detected)) if sr.0 == detected.service => Some(detected.clone()),
47 (Some((service, _)), None) if *service == "ecr" => Some(protocol::DetectedRequest {
53 service: "ecr".to_string(),
54 action: String::new(),
55 protocol: AwsProtocol::Rest,
56 }),
57 _ => None,
58 };
59
60 let (body_bytes, body_stream) = if stream_dispatch.is_some() {
61 (Bytes::new(), Some(body))
62 } else {
63 let max_body_bytes = max_request_body_bytes();
68 match axum::body::to_bytes(body, max_body_bytes).await {
69 Ok(b) => (b, None),
70 Err(_) => {
71 return build_error_response(
72 StatusCode::PAYLOAD_TOO_LARGE,
73 "RequestEntityTooLarge",
74 "Request body too large",
75 &request_id,
76 AwsProtocol::Query,
77 );
78 }
79 }
80 };
81
82 let detected = if let Some(d) = stream_dispatch {
84 d
85 } else {
86 match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
87 Some(d) => d,
88 None => {
89 if parts.method == http::Method::OPTIONS {
95 protocol::DetectedRequest {
96 service: "s3".to_string(),
97 action: String::new(),
98 protocol: AwsProtocol::Rest,
99 }
100 } else if parts.uri.path() == "/v2" || parts.uri.path().starts_with("/v2/") {
101 protocol::DetectedRequest {
105 service: "ecr".to_string(),
106 action: String::new(),
107 protocol: AwsProtocol::Rest,
108 }
109 } else if !parts.uri.path().starts_with("/_") {
110 protocol::DetectedRequest {
115 service: "apigateway".to_string(),
116 action: String::new(),
117 protocol: AwsProtocol::RestJson,
118 }
119 } else {
120 return build_error_response(
121 StatusCode::BAD_REQUEST,
122 "MissingAction",
123 "Could not determine target service or action from request",
124 &request_id,
125 AwsProtocol::Query,
126 );
127 }
128 }
129 }
130 };
131
132 let detected = if detected.service == "bedrock" {
136 let first_seg = parts.uri.path().split('/').nth(1);
137 if matches!(
138 first_seg,
139 Some(
140 "agents"
141 | "knowledgebases"
142 | "flows"
143 | "prompts"
144 | "tags"
145 | "retrieveAndGenerate"
146 | "retrieveAndGenerateStream"
147 | "optimize-prompt"
148 | "sessions"
149 | "invocations"
150 | "generate-query"
151 | "rerank"
152 )
153 ) {
154 let segs: Vec<&str> = parts.uri.path().split('/').collect();
156 let is_runtime = matches!(
157 segs.as_slice(),
158 ["", "agents", _, "agentAliases", _, ..] | ["", "flows", _, "aliases", _] | ["", "knowledgebases", _, "retrieve"] | ["", "retrieveAndGenerate"]
162 | ["", "retrieveAndGenerateStream"]
163 | ["", "optimize-prompt"]
164 | ["", "sessions", ..]
165 | ["", "invocations", ..]
166 | ["", "generate-query"]
167 | ["", "rerank"]
168 );
169 if is_runtime {
170 protocol::DetectedRequest {
171 service: "bedrock-agent-runtime".to_string(),
172 ..detected
173 }
174 } else {
175 protocol::DetectedRequest {
176 service: "bedrock-agent".to_string(),
177 ..detected
178 }
179 }
180 } else {
181 detected
182 }
183 } else {
184 detected
185 };
186
187 let service = match registry.get(&detected.service) {
189 Some(s) => s,
190 None => {
191 return build_error_response(
192 detected.protocol.error_status(),
193 "UnknownService",
194 &format!("Service '{}' is not available", detected.service),
195 &request_id,
196 detected.protocol,
197 );
198 }
199 };
200
201 let auth_header = parts
203 .headers
204 .get("authorization")
205 .and_then(|v| v.to_str().ok())
206 .unwrap_or("");
207 let header_info = fakecloud_aws::sigv4::parse_sigv4(auth_header);
208 let presigned_info = if header_info.is_none() {
209 fakecloud_aws::sigv4::parse_sigv4_presigned(&query_params).map(|p| p.as_info())
211 } else {
212 None
213 };
214 let sigv4_info = header_info.or(presigned_info);
215 let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());
216
217 let host_info = protocol::parse_routing_host_from_headers(&parts.headers);
223
224 let region = sigv4_info
225 .map(|info| info.region)
226 .or_else(|| host_info.as_ref().map(|h| h.region.clone()))
227 .or_else(|| extract_region_from_user_agent(&parts.headers))
228 .unwrap_or_else(|| config.region.clone());
229
230 let caller_akid = access_key_id.as_deref().unwrap_or("");
236 let resolved = if !caller_akid.is_empty() && !is_root_bypass(caller_akid) {
237 config
238 .credential_resolver
239 .as_ref()
240 .and_then(|r| r.resolve(caller_akid))
241 } else {
242 None
243 };
244 let caller_principal = resolved.as_ref().map(|r| r.principal.clone());
245 let caller_session_policies = resolved
246 .as_ref()
247 .map(|r| r.session_policies.clone())
248 .unwrap_or_default();
249
250 if config.verify_sigv4 && !is_root_bypass(caller_akid) && config.credential_resolver.is_some() {
255 let amz_date = parts
256 .headers
257 .get("x-amz-date")
258 .and_then(|v| v.to_str().ok());
259 let parsed = fakecloud_aws::sigv4::parse_sigv4_header(auth_header, amz_date)
260 .or_else(|| fakecloud_aws::sigv4::parse_sigv4_presigned(&query_params));
261 let parsed = match parsed {
262 Some(p) => p,
263 None => {
264 return build_error_response(
265 StatusCode::FORBIDDEN,
266 "IncompleteSignature",
267 "Request is missing or has a malformed AWS Signature",
268 &request_id,
269 detected.protocol,
270 );
271 }
272 };
273 let resolved_for_verify = match resolved.as_ref() {
274 Some(r) => r,
275 None => {
276 return build_error_response(
277 StatusCode::FORBIDDEN,
278 "InvalidClientTokenId",
279 "The security token included in the request is invalid",
280 &request_id,
281 detected.protocol,
282 );
283 }
284 };
285 let headers_vec = fakecloud_aws::sigv4::headers_from_http(&parts.headers);
286 let raw_query_for_verify = parts.uri.query().unwrap_or("").to_string();
287 let verify_req = fakecloud_aws::sigv4::VerifyRequest {
288 method: parts.method.as_str(),
289 path: parts.uri.path(),
290 query: &raw_query_for_verify,
291 headers: &headers_vec,
292 body: &body_bytes,
293 };
294 match fakecloud_aws::sigv4::verify(
295 &parsed,
296 &verify_req,
297 &resolved_for_verify.secret_access_key,
298 chrono::Utc::now(),
299 ) {
300 Ok(()) => {}
301 Err(fakecloud_aws::sigv4::SigV4Error::RequestTimeTooSkewed { .. }) => {
302 return build_error_response(
303 StatusCode::FORBIDDEN,
304 "RequestTimeTooSkewed",
305 "The difference between the request time and the current time is too large",
306 &request_id,
307 detected.protocol,
308 );
309 }
310 Err(fakecloud_aws::sigv4::SigV4Error::InvalidDate(msg)) => {
311 return build_error_response(
312 StatusCode::FORBIDDEN,
313 "IncompleteSignature",
314 &format!("Invalid x-amz-date: {msg}"),
315 &request_id,
316 detected.protocol,
317 );
318 }
319 Err(fakecloud_aws::sigv4::SigV4Error::Malformed(msg)) => {
320 return build_error_response(
321 StatusCode::FORBIDDEN,
322 "IncompleteSignature",
323 &format!("Malformed SigV4 signature: {msg}"),
324 &request_id,
325 detected.protocol,
326 );
327 }
328 Err(fakecloud_aws::sigv4::SigV4Error::SignatureMismatch) => {
329 return build_error_response(
330 StatusCode::FORBIDDEN,
331 "SignatureDoesNotMatch",
332 "The request signature we calculated does not match the signature you provided",
333 &request_id,
334 detected.protocol,
335 );
336 }
337 }
338 }
339
340 let wire_path = parts.uri.path();
345 let path = if detected.service == "s3" {
346 if let Some(bucket) = host_info.as_ref().and_then(|h| h.bucket.as_deref()) {
347 let prefix_with_slash = format!("/{bucket}/");
348 let is_bucket_root = wire_path.trim_end_matches('/') == format!("/{bucket}");
349 if wire_path.starts_with(&prefix_with_slash) || is_bucket_root {
350 wire_path.to_string()
351 } else if wire_path == "/" || wire_path.is_empty() {
352 format!("/{bucket}")
353 } else {
354 format!("/{bucket}{wire_path}")
355 }
356 } else {
357 wire_path.to_string()
358 }
359 } else {
360 wire_path.to_string()
361 };
362 let raw_query = parts.uri.query().unwrap_or("").to_string();
363 let path_segments: Vec<String> = path
364 .split('/')
365 .filter(|s| !s.is_empty())
366 .map(|s| s.to_string())
367 .collect();
368
369 if detected.protocol == AwsProtocol::Json
371 && !body_bytes.is_empty()
372 && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
373 {
374 return build_error_response(
375 StatusCode::BAD_REQUEST,
376 "SerializationException",
377 "Start of structure or map found where not expected",
378 &request_id,
379 AwsProtocol::Json,
380 );
381 }
382
383 let mut all_params = query_params;
385 if detected.protocol == AwsProtocol::Query {
386 let body_params = protocol::parse_query_body(&body_bytes);
387 for (k, v) in body_params {
388 all_params.entry(k).or_insert(v);
389 }
390 }
391
392 let aws_request = AwsRequest {
393 service: detected.service.clone(),
394 action: detected.action.clone(),
395 region,
396 account_id: caller_principal
397 .as_ref()
398 .map(|p| p.account_id.clone())
399 .unwrap_or_else(|| config.account_id.clone()),
400 request_id: request_id.clone(),
401 headers: parts.headers,
402 query_params: all_params,
403 body: body_bytes,
404 body_stream: parking_lot::Mutex::new(body_stream),
405 path_segments,
406 raw_path: path,
407 raw_query,
408 method: parts.method,
409 is_query_protocol: detected.protocol == AwsProtocol::Query,
410 access_key_id,
411 principal: caller_principal,
412 };
413
414 tracing::info!(
415 service = %aws_request.service,
416 action = %aws_request.action,
417 request_id = %aws_request.request_id,
418 "handling request"
419 );
420
421 if config.iam_mode.is_enabled()
428 && service.iam_enforceable()
429 && !is_root_bypass(aws_request.access_key_id.as_deref().unwrap_or(""))
430 {
431 if let Some(evaluator) = config.policy_evaluator.as_ref() {
432 if let Some(principal) = aws_request.principal.as_ref() {
433 if !principal.is_root() {
434 if let Some(iam_action) = service.iam_action_for(&aws_request) {
435 let mut condition_context = build_condition_context(
436 principal,
437 remote_addr,
438 &aws_request.region,
439 is_secure_transport(&aws_request.headers),
440 );
441 if let Some(rc) = resolved.as_ref() {
449 condition_context.aws_mfa_present = Some(rc.mfa_present);
450 condition_context.aws_token_issue_time = rc.token_issued_at;
451 condition_context.aws_federated_provider =
452 rc.federated_provider.clone();
453 if rc.mfa_present {
461 if let Some(issued) = rc.token_issued_at {
462 let age = chrono::Utc::now()
463 .signed_duration_since(issued)
464 .num_seconds()
465 .max(0);
466 condition_context.aws_mfa_age_seconds = Some(age);
467 }
468 }
469 }
470 condition_context.service_keys =
471 service.iam_condition_keys_for(&aws_request, &iam_action);
472
473 match service.resource_tags_for(&iam_action.resource) {
476 Some(tags) => condition_context.resource_tags = Some(tags),
477 None => tracing::debug!(
478 target: "fakecloud::iam::audit",
479 service = %detected.service,
480 resource = %iam_action.resource,
481 "service does not expose resource tags for ABAC; skipping aws:ResourceTag/* evaluation"
482 ),
483 }
484 match service.request_tags_from(&aws_request, iam_action.action) {
486 Some(tags) => condition_context.request_tags = Some(tags),
487 None => tracing::debug!(
488 target: "fakecloud::iam::audit",
489 service = %detected.service,
490 action = %iam_action.action_string(),
491 "service does not expose request tags for ABAC; skipping aws:RequestTag/* / aws:TagKeys evaluation"
492 ),
493 }
494 condition_context.principal_tags = principal.tags.clone();
496
497 let resource_policy_json =
506 config.resource_policy_provider.as_ref().and_then(|p| {
507 p.resource_policy(&detected.service, &iam_action.resource)
508 });
509 let resource_account_id = parse_account_from_arn(&iam_action.resource)
515 .unwrap_or_else(|| principal.account_id.clone());
516 let scps = config
523 .scp_resolver
524 .as_ref()
525 .and_then(|r| r.scps_for(principal));
526 let decision = evaluator.evaluate_with_resource_policy(
527 principal,
528 &iam_action,
529 &condition_context,
530 resource_policy_json.as_deref(),
531 &resource_account_id,
532 &caller_session_policies,
533 scps.as_deref(),
534 );
535 if !decision.is_allow() {
536 tracing::warn!(
537 target: "fakecloud::iam::audit",
538 service = %detected.service,
539 action = %iam_action.action_string(),
540 resource = %iam_action.resource,
541 principal = %principal.arn,
542 resource_policy_present = resource_policy_json.is_some(),
543 decision = ?decision,
544 mode = %config.iam_mode,
545 request_id = %request_id,
546 "IAM policy evaluation denied request"
547 );
548 if config.iam_mode.is_strict() {
549 let context_summary = serde_json::json!({
562 "aws:PrincipalArn": principal.arn,
563 "aws:PrincipalAccount": principal.account_id,
564 "aws:RequestedRegion": condition_context
565 .aws_requested_region
566 .clone()
567 .unwrap_or_default(),
568 "aws:SecureTransport": condition_context
569 .aws_secure_transport
570 .unwrap_or(false),
571 "aws:Action": iam_action.action_string(),
572 "aws:Resource": iam_action.resource,
573 "decision": format!("{:?}", decision),
574 });
575 let action_string = iam_action.action_string();
576 let encoded = crate::auth_message::encode_deny(
577 matches!(decision, crate::auth::IamDecision::ExplicitDeny),
578 Some(&action_string),
579 Some(&principal.arn),
580 Vec::new(),
581 Some(context_summary),
582 );
583 return build_error_response(
584 StatusCode::FORBIDDEN,
585 "AccessDeniedException",
586 &format!(
587 "User: {} is not authorized to perform: {} on resource: {} Encoded authorization failure message: {}",
588 principal.arn,
589 iam_action.action_string(),
590 iam_action.resource,
591 encoded,
592 ),
593 &request_id,
594 detected.protocol,
595 );
596 }
597 }
600 } else {
601 tracing::warn!(
606 target: "fakecloud::iam::audit",
607 service = %detected.service,
608 action = %aws_request.action,
609 "service is iam_enforceable but has no IamAction mapping for this action; skipping evaluation"
610 );
611 }
612 }
613 }
614 }
615 }
616
617 match service.handle(aws_request).await {
618 Ok(resp) => {
619 let mut builder = Response::builder()
620 .status(resp.status)
621 .header("x-amzn-requestid", &request_id)
622 .header("x-amz-request-id", &request_id);
623
624 if !resp.content_type.is_empty() {
625 builder = builder.header("content-type", &resp.content_type);
626 }
627
628 let has_content_length = resp
629 .headers
630 .iter()
631 .any(|(k, _)| k.as_str().eq_ignore_ascii_case("content-length"));
632
633 for (k, v) in &resp.headers {
634 builder = builder.header(k, v);
635 }
636
637 match resp.body {
638 ResponseBody::Bytes(b) => builder.body(Body::from(b)).unwrap(),
639 ResponseBody::File { file, size } => {
640 let stream = tokio_util::io::ReaderStream::new(file);
641 let body = Body::from_stream(stream);
642 if !has_content_length {
643 builder = builder.header("content-length", size.to_string());
644 }
645 builder.body(body).unwrap()
646 }
647 }
648 }
649 Err(err) => {
650 tracing::warn!(
651 service = %detected.service,
652 action = %detected.action,
653 error = %err,
654 "request failed"
655 );
656 let error_headers = err.response_headers().to_vec();
657 let mut resp = build_error_response_with_fields(
658 err.status(),
659 err.code(),
660 &err.message(),
661 &request_id,
662 detected.protocol,
663 err.extra_fields(),
664 );
665 for (k, v) in &error_headers {
666 if let (Ok(name), Ok(val)) = (
667 k.parse::<http::header::HeaderName>(),
668 v.parse::<http::header::HeaderValue>(),
669 ) {
670 resp.headers_mut().insert(name, val);
671 }
672 }
673 resp
674 }
675 }
676}
677
/// Configuration shared by every call to `dispatch`.
///
/// Trait-object fields are optional: when one is `None`, the feature that
/// depends on it (caller resolution, IAM evaluation, resource policies, SCPs)
/// is skipped.
#[derive(Clone)]
pub struct DispatchConfig {
    /// Region used when neither the SigV4 credential scope, the routing host,
    /// nor a `region/...` User-Agent token supplies one.
    pub region: String,
    /// Account id assigned to requests whose caller was not resolved to a principal.
    pub account_id: String,
    /// When `true` (and `credential_resolver` is set), non-root-bypass requests
    /// must carry a SigV4 signature that verifies against the resolved secret.
    pub verify_sigv4: bool,
    /// IAM enforcement mode. When enabled, enforceable services are evaluated;
    /// denials are always logged but rejected only when the mode is strict.
    pub iam_mode: IamMode,
    /// Maps an access key id to its principal, secret key and session data.
    pub credential_resolver: Option<Arc<dyn CredentialResolver>>,
    /// Evaluates identity, resource, session and SCP policies for a request.
    pub policy_evaluator: Option<Arc<dyn IamPolicyEvaluator>>,
    /// Looks up the resource-based policy document for a service resource, if any.
    pub resource_policy_provider: Option<Arc<dyn ResourcePolicyProvider>>,
    /// Supplies the service control policies applicable to a principal, if any.
    pub scp_resolver: Option<Arc<dyn crate::auth::ScpResolver>>,
}
715
716impl std::fmt::Debug for DispatchConfig {
717 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
718 f.debug_struct("DispatchConfig")
719 .field("region", &self.region)
720 .field("account_id", &self.account_id)
721 .field("verify_sigv4", &self.verify_sigv4)
722 .field("iam_mode", &self.iam_mode)
723 .field(
724 "credential_resolver",
725 &self
726 .credential_resolver
727 .as_ref()
728 .map(|_| "<CredentialResolver>"),
729 )
730 .field(
731 "policy_evaluator",
732 &self
733 .policy_evaluator
734 .as_ref()
735 .map(|_| "<IamPolicyEvaluator>"),
736 )
737 .field(
738 "resource_policy_provider",
739 &self
740 .resource_policy_provider
741 .as_ref()
742 .map(|_| "<ResourcePolicyProvider>"),
743 )
744 .field(
745 "scp_resolver",
746 &self.scp_resolver.as_ref().map(|_| "<ScpResolver>"),
747 )
748 .finish()
749 }
750}
751
752impl DispatchConfig {
753 pub fn new(region: impl Into<String>, account_id: impl Into<String>) -> Self {
756 Self {
757 region: region.into(),
758 account_id: account_id.into(),
759 verify_sigv4: false,
760 iam_mode: IamMode::Off,
761 credential_resolver: None,
762 policy_evaluator: None,
763 resource_policy_provider: None,
764 scp_resolver: None,
765 }
766 }
767}
768
769fn streaming_route(
789 method: &http::Method,
790 path: &str,
791 headers: &http::HeaderMap,
792 query_params: &HashMap<String, String>,
793) -> Option<(&'static str, &'static str)> {
794 if (method == http::Method::PATCH || method == http::Method::PUT)
796 && path.starts_with("/v2/")
797 && path.contains("/blobs/uploads/")
798 {
799 return Some(("ecr", ""));
800 }
801
802 if method == http::Method::PUT {
807 let after = path.trim_start_matches('/');
808 let virtual_hosted_s3 = protocol::parse_routing_host_from_headers(headers)
814 .filter(|h| h.service == "s3" && h.bucket.is_some())
815 .is_some();
816 if after.is_empty() || (!virtual_hosted_s3 && !after.contains('/')) {
817 return None;
818 }
819 let header_s3 = headers
820 .get("authorization")
821 .and_then(|v| v.to_str().ok())
822 .and_then(fakecloud_aws::sigv4::parse_sigv4)
823 .map(|info| info.service == "s3")
824 .unwrap_or(false);
825 let presigned_v4_s3 = query_params
826 .get("X-Amz-Credential")
827 .and_then(|c| c.split('/').nth(3).map(|s| s.to_string()))
828 .map(|service| service == "s3")
829 .unwrap_or(false);
830 let presigned_v2 = query_params.contains_key("AWSAccessKeyId")
831 && query_params.contains_key("Signature")
832 && query_params.contains_key("Expires");
833 if header_s3 || presigned_v4_s3 || presigned_v2 {
834 return Some(("s3", ""));
835 }
836 }
837
838 None
839}
840
/// Default cap for buffered request bodies: 1 GiB.
const DEFAULT_MAX_REQUEST_BODY_BYTES: usize = 1 << 30;

/// Maximum size, in bytes, of a request body that is buffered in memory.
///
/// Read once from `FAKECLOUD_MAX_REQUEST_BODY_BYTES` and cached for the rest of
/// the process. Missing, non-numeric, or zero values fall back to
/// `DEFAULT_MAX_REQUEST_BODY_BYTES`.
fn max_request_body_bytes() -> usize {
    static LIMIT: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *LIMIT.get_or_init(|| {
        let configured = match std::env::var("FAKECLOUD_MAX_REQUEST_BODY_BYTES") {
            Ok(raw) => raw.parse::<usize>().ok(),
            Err(_) => None,
        };
        match configured {
            Some(limit) if limit > 0 => limit,
            _ => DEFAULT_MAX_REQUEST_BODY_BYTES,
        }
    })
}
862
/// Extracts the account id (5th field) from an ARN of the form
/// `arn:partition:service:region:account:resource`.
///
/// Returns `None` for strings that do not start with `arn`, have fewer than
/// six colon-separated fields, or have an empty account field (e.g. S3 bucket
/// ARNs). The resource field may itself contain colons.
fn parse_account_from_arn(arn: &str) -> Option<String> {
    let fields: Vec<&str> = arn.splitn(6, ':').collect();
    match fields.as_slice() {
        ["arn", _partition, _service, _region, account, _resource] if !account.is_empty() => {
            Some((*account).to_string())
        }
        _ => None,
    }
}
886
887fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
889 let ua = headers.get("user-agent")?.to_str().ok()?;
890 for part in ua.split_whitespace() {
891 if let Some(region) = part.strip_prefix("region/") {
892 if !region.is_empty() {
893 return Some(region.to_string());
894 }
895 }
896 }
897 None
898}
899
900fn build_error_response(
901 status: StatusCode,
902 code: &str,
903 message: &str,
904 request_id: &str,
905 protocol: AwsProtocol,
906) -> Response<Body> {
907 build_error_response_with_fields(status, code, message, request_id, protocol, &[])
908}
909
910fn build_error_response_with_fields(
911 status: StatusCode,
912 code: &str,
913 message: &str,
914 request_id: &str,
915 protocol: AwsProtocol,
916 extra_fields: &[(String, String)],
917) -> Response<Body> {
918 let (status, content_type, body) = match protocol {
919 AwsProtocol::Query => {
920 fakecloud_aws::error::xml_error_response(status, code, message, request_id)
921 }
922 AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
923 status,
924 code,
925 message,
926 request_id,
927 extra_fields,
928 ),
929 AwsProtocol::Json | AwsProtocol::RestJson => {
930 fakecloud_aws::error::json_error_response(status, code, message)
931 }
932 };
933
934 let safe_code = sanitize_header_value(code);
944 let safe_message = sanitize_header_value(message);
945 let mut builder = Response::builder()
946 .status(status)
947 .header("content-type", content_type)
948 .header("x-amzn-requestid", request_id)
949 .header("x-amz-request-id", request_id);
950 if let Ok(v) = http::HeaderValue::from_str(&safe_code) {
951 builder = builder.header("x-amz-error-code", v);
952 }
953 if let Ok(v) = http::HeaderValue::from_str(&safe_message) {
954 builder = builder.header("x-amz-error-message", v);
955 }
956 builder.body(Body::from(body)).unwrap_or_else(|_| {
957 Response::new(Body::empty())
961 })
962}
963
/// Converts arbitrary text into something `http::HeaderValue::from_str` will
/// accept.
///
/// * Control characters (CR, LF, tab, DEL, …) become a single space. A control
///   character that directly follows a space adds nothing.
/// * Non-ASCII characters become `?`. `HeaderValue::from_str` only accepts
///   visible ASCII, so one such character would otherwise make callers drop
///   the whole header.
/// * Output is capped at `MAX_LEN` bytes. Every emitted character is one byte,
///   so the cap is exact.
/// * Leading and trailing whitespace is trimmed.
fn sanitize_header_value(s: &str) -> String {
    const MAX_LEN: usize = 1024;
    let mut out = String::with_capacity(s.len().min(MAX_LEN));
    for ch in s.chars() {
        if out.len() >= MAX_LEN {
            break;
        }
        if ch.is_control() {
            // Collapse CR/LF and friends so the value cannot split the header.
            if !out.ends_with(' ') {
                out.push(' ');
            }
        } else if ch.is_ascii() {
            out.push(ch);
        } else {
            out.push('?');
        }
    }
    out.trim().to_string()
}
987
988fn build_condition_context(
993 principal: &Principal,
994 remote_addr: Option<SocketAddr>,
995 region: &str,
996 secure_transport: bool,
997) -> ConditionContext {
998 let now = chrono::Utc::now();
999 ConditionContext {
1000 aws_username: aws_username_from_principal(principal),
1001 aws_userid: Some(principal.user_id.clone()),
1002 aws_principal_arn: Some(principal.arn.clone()),
1003 aws_principal_account: Some(principal.account_id.clone()),
1004 aws_principal_type: Some(principal_type_label(principal.principal_type).to_string()),
1005 aws_source_ip: remote_addr.map(|sa| sa.ip()),
1006 aws_current_time: Some(now),
1007 aws_epoch_time: Some(now.timestamp()),
1008 aws_secure_transport: Some(secure_transport),
1009 aws_requested_region: Some(region.to_string()),
1010 aws_mfa_present: None,
1016 aws_mfa_age_seconds: None,
1017 aws_called_via: Vec::new(),
1018 aws_source_vpce: None,
1019 aws_source_vpc: None,
1020 aws_vpc_source_ip: None,
1021 aws_federated_provider: None,
1022 aws_token_issue_time: None,
1023 service_keys: Default::default(),
1024 resource_tags: None,
1025 request_tags: None,
1026 principal_tags: None,
1027 }
1028}
1029
1030fn aws_username_from_principal(principal: &Principal) -> Option<String> {
1034 if principal.principal_type != PrincipalType::User {
1035 return None;
1036 }
1037 let after = principal.arn.rsplit_once(":user/").map(|(_, s)| s)?;
1038 Some(after.rsplit('/').next().unwrap_or(after).to_string())
1040}
1041
1042fn principal_type_label(t: PrincipalType) -> &'static str {
1045 match t {
1046 PrincipalType::User => "User",
1047 PrincipalType::AssumedRole => "AssumedRole",
1048 PrincipalType::FederatedUser => "FederatedUser",
1049 PrincipalType::Root => "Account",
1050 PrincipalType::Unknown => "Unknown",
1051 }
1052}
1053
1054fn is_secure_transport(headers: &http::HeaderMap) -> bool {
1060 headers
1061 .get("x-forwarded-proto")
1062 .and_then(|v| v.to_str().ok())
1063 .map(|s| s.eq_ignore_ascii_case("https"))
1064 .unwrap_or(false)
1065}
1066
1067trait ProtocolExt {
1068 fn error_status(&self) -> StatusCode;
1069}
1070
1071impl ProtocolExt for AwsProtocol {
1072 fn error_status(&self) -> StatusCode {
1073 StatusCode::BAD_REQUEST
1074 }
1075}
1076
1077#[cfg(test)]
1078mod tests {
1079 use super::*;
1080
1081 #[test]
1082 fn default_max_request_body_bytes_is_one_gib() {
1083 assert_eq!(DEFAULT_MAX_REQUEST_BODY_BYTES, 1024 * 1024 * 1024);
1087 }
1088
1089 #[test]
1090 fn dispatch_config_new_defaults_to_off() {
1091 let cfg = DispatchConfig::new("us-east-1", "123456789012");
1092 assert_eq!(cfg.region, "us-east-1");
1093 assert_eq!(cfg.account_id, "123456789012");
1094 assert!(!cfg.verify_sigv4);
1095 assert_eq!(cfg.iam_mode, IamMode::Off);
1096 }
1097
1098 #[test]
1099 fn aws_username_strips_iam_path_for_users() {
1100 let p = Principal {
1101 arn: "arn:aws:iam::123456789012:user/engineering/alice".into(),
1102 user_id: "AIDAALICE".into(),
1103 account_id: "123456789012".into(),
1104 principal_type: PrincipalType::User,
1105 source_identity: None,
1106 tags: None,
1107 };
1108 assert_eq!(aws_username_from_principal(&p), Some("alice".into()));
1109 }
1110
1111 #[test]
1112 fn aws_username_unset_for_assumed_role() {
1113 let p = Principal {
1114 arn: "arn:aws:sts::123456789012:assumed-role/ops/session".into(),
1115 user_id: "AROAOPS:session".into(),
1116 account_id: "123456789012".into(),
1117 principal_type: PrincipalType::AssumedRole,
1118 source_identity: None,
1119 tags: None,
1120 };
1121 assert_eq!(aws_username_from_principal(&p), None);
1122 }
1123
1124 #[test]
1125 fn principal_type_label_matches_aws_casing() {
1126 assert_eq!(principal_type_label(PrincipalType::User), "User");
1127 assert_eq!(
1128 principal_type_label(PrincipalType::AssumedRole),
1129 "AssumedRole"
1130 );
1131 assert_eq!(principal_type_label(PrincipalType::Root), "Account");
1132 }
1133
1134 #[test]
1135 fn build_condition_context_populates_global_keys() {
1136 let p = Principal {
1137 arn: "arn:aws:iam::123456789012:user/alice".into(),
1138 user_id: "AIDAALICE".into(),
1139 account_id: "123456789012".into(),
1140 principal_type: PrincipalType::User,
1141 source_identity: None,
1142 tags: None,
1143 };
1144 let addr: SocketAddr = "10.0.0.1:54321".parse().unwrap();
1145 let ctx = build_condition_context(&p, Some(addr), "us-east-1", false);
1146 assert_eq!(ctx.aws_username.as_deref(), Some("alice"));
1147 assert_eq!(ctx.aws_userid.as_deref(), Some("AIDAALICE"));
1148 assert_eq!(
1149 ctx.aws_principal_arn.as_deref(),
1150 Some("arn:aws:iam::123456789012:user/alice")
1151 );
1152 assert_eq!(ctx.aws_principal_account.as_deref(), Some("123456789012"));
1153 assert_eq!(ctx.aws_principal_type.as_deref(), Some("User"));
1154 assert_eq!(
1155 ctx.aws_source_ip.map(|i| i.to_string()).as_deref(),
1156 Some("10.0.0.1")
1157 );
1158 assert_eq!(ctx.aws_requested_region.as_deref(), Some("us-east-1"));
1159 assert_eq!(ctx.aws_secure_transport, Some(false));
1160 assert!(ctx.aws_current_time.is_some());
1161 assert!(ctx.aws_epoch_time.is_some());
1162 }
1163
1164 #[test]
1165 fn is_secure_transport_reads_x_forwarded_proto() {
1166 let mut headers = http::HeaderMap::new();
1167 headers.insert("x-forwarded-proto", "https".parse().unwrap());
1168 assert!(is_secure_transport(&headers));
1169 headers.insert("x-forwarded-proto", "http".parse().unwrap());
1170 assert!(!is_secure_transport(&headers));
1171 let empty = http::HeaderMap::new();
1172 assert!(!is_secure_transport(&empty));
1173 }
1174
1175 #[test]
1176 fn parse_account_from_arn_extracts_standard_shapes() {
1177 assert_eq!(
1178 parse_account_from_arn("arn:aws:sqs:us-east-1:123456789012:queue"),
1179 Some("123456789012".to_string())
1180 );
1181 assert_eq!(
1182 parse_account_from_arn("arn:aws:iam::123456789012:user/alice"),
1183 Some("123456789012".to_string())
1184 );
1185 }
1186
1187 #[test]
1188 fn parse_account_from_arn_returns_none_for_s3_empty_account() {
1189 assert_eq!(parse_account_from_arn("arn:aws:s3:::my-bucket"), None);
1191 assert_eq!(
1192 parse_account_from_arn("arn:aws:s3:::my-bucket/path/to/key"),
1193 None
1194 );
1195 }
1196
1197 #[test]
1198 fn parse_account_from_arn_returns_none_for_malformed() {
1199 assert_eq!(parse_account_from_arn(""), None);
1200 assert_eq!(parse_account_from_arn("not-an-arn"), None);
1201 assert_eq!(parse_account_from_arn("arn:aws:sqs:us-east-1"), None);
1202 assert_eq!(parse_account_from_arn("arn:aws:sqs"), None);
1203 }
1204
1205 #[test]
1206 fn extract_region_from_user_agent_finds_region_segment() {
1207 let mut headers = http::HeaderMap::new();
1208 headers.insert(
1209 "user-agent",
1210 "aws-sdk-rust/1.0 os/linux region/eu-central-1"
1211 .parse()
1212 .unwrap(),
1213 );
1214 assert_eq!(
1215 extract_region_from_user_agent(&headers),
1216 Some("eu-central-1".to_string())
1217 );
1218 }
1219
1220 #[test]
1221 fn extract_region_from_user_agent_none_without_header() {
1222 let headers = http::HeaderMap::new();
1223 assert_eq!(extract_region_from_user_agent(&headers), None);
1224 }
1225
1226 #[test]
1227 fn extract_region_from_user_agent_ignores_empty_region() {
1228 let mut headers = http::HeaderMap::new();
1229 headers.insert("user-agent", "aws-sdk-java region/".parse().unwrap());
1230 assert_eq!(extract_region_from_user_agent(&headers), None);
1231 }
1232
1233 #[test]
1234 fn extract_region_from_user_agent_none_when_no_region_marker() {
1235 let mut headers = http::HeaderMap::new();
1236 headers.insert("user-agent", "curl/7.79.1".parse().unwrap());
1237 assert_eq!(extract_region_from_user_agent(&headers), None);
1238 }
1239
1240 #[test]
1241 fn aws_username_none_for_root() {
1242 let p = Principal {
1243 arn: "arn:aws:iam::123456789012:root".into(),
1244 user_id: "123456789012".into(),
1245 account_id: "123456789012".into(),
1246 principal_type: PrincipalType::Root,
1247 source_identity: None,
1248 tags: None,
1249 };
1250 assert_eq!(aws_username_from_principal(&p), None);
1251 }
1252
1253 #[test]
1254 fn aws_username_bare_no_path() {
1255 let p = Principal {
1256 arn: "arn:aws:iam::123456789012:user/bob".into(),
1257 user_id: "AIDABOB".into(),
1258 account_id: "123456789012".into(),
1259 principal_type: PrincipalType::User,
1260 source_identity: None,
1261 tags: None,
1262 };
1263 assert_eq!(aws_username_from_principal(&p), Some("bob".into()));
1264 }
1265
1266 #[test]
1267 fn principal_type_label_covers_federated_and_unknown() {
1268 assert_eq!(
1269 principal_type_label(PrincipalType::FederatedUser),
1270 "FederatedUser"
1271 );
1272 assert_eq!(principal_type_label(PrincipalType::Unknown), "Unknown");
1273 }
1274
1275 #[test]
1276 fn build_condition_context_marks_secure_when_flag_set() {
1277 let p = Principal {
1278 arn: "arn:aws:iam::123456789012:user/alice".into(),
1279 user_id: "AIDAALICE".into(),
1280 account_id: "123456789012".into(),
1281 principal_type: PrincipalType::User,
1282 source_identity: None,
1283 tags: None,
1284 };
1285 let ctx = build_condition_context(&p, None, "us-west-2", true);
1286 assert_eq!(ctx.aws_secure_transport, Some(true));
1287 assert!(ctx.aws_source_ip.is_none());
1288 assert_eq!(ctx.aws_requested_region.as_deref(), Some("us-west-2"));
1289 }
1290
1291 #[test]
1292 fn is_secure_transport_case_insensitive() {
1293 let mut headers = http::HeaderMap::new();
1294 headers.insert("x-forwarded-proto", "HTTPS".parse().unwrap());
1295 assert!(is_secure_transport(&headers));
1296 }
1297
1298 #[test]
1299 fn is_secure_transport_non_ascii_bytes_false() {
1300 let mut headers = http::HeaderMap::new();
1301 headers.insert(
1302 "x-forwarded-proto",
1303 http::HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap(),
1304 );
1305 assert!(!is_secure_transport(&headers));
1306 }
1307
1308 #[test]
1309 fn protocol_ext_error_status_is_bad_request() {
1310 assert_eq!(AwsProtocol::Query.error_status(), StatusCode::BAD_REQUEST);
1311 assert_eq!(AwsProtocol::Json.error_status(), StatusCode::BAD_REQUEST);
1312 assert_eq!(AwsProtocol::Rest.error_status(), StatusCode::BAD_REQUEST);
1313 assert_eq!(
1314 AwsProtocol::RestJson.error_status(),
1315 StatusCode::BAD_REQUEST
1316 );
1317 }
1318
1319 #[test]
1320 fn build_error_response_json_has_json_content_type() {
1321 let resp = build_error_response(
1322 StatusCode::BAD_REQUEST,
1323 "TestCode",
1324 "test msg",
1325 "req-1",
1326 AwsProtocol::Json,
1327 );
1328 assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
1329 let ct = resp
1330 .headers()
1331 .get("content-type")
1332 .unwrap()
1333 .to_str()
1334 .unwrap();
1335 assert!(ct.contains("json"));
1336 let rid = resp
1337 .headers()
1338 .get("x-amzn-requestid")
1339 .unwrap()
1340 .to_str()
1341 .unwrap();
1342 assert_eq!(rid, "req-1");
1343 }
1344
1345 #[test]
1346 fn build_error_response_rest_returns_xml_content_type() {
1347 let resp = build_error_response(
1348 StatusCode::NOT_FOUND,
1349 "NoSuchBucket",
1350 "bucket missing",
1351 "req-2",
1352 AwsProtocol::Rest,
1353 );
1354 assert_eq!(resp.status(), StatusCode::NOT_FOUND);
1355 let ct = resp
1356 .headers()
1357 .get("content-type")
1358 .unwrap()
1359 .to_str()
1360 .unwrap();
1361 assert!(ct.contains("xml"));
1362 }
1363
1364 #[test]
1365 fn build_error_response_query_returns_xml() {
1366 let resp = build_error_response(
1367 StatusCode::BAD_REQUEST,
1368 "InvalidParameter",
1369 "bad param",
1370 "req-3",
1371 AwsProtocol::Query,
1372 );
1373 let ct = resp
1374 .headers()
1375 .get("content-type")
1376 .unwrap()
1377 .to_str()
1378 .unwrap();
1379 assert!(ct.contains("xml"));
1380 }
1381
1382 #[test]
1387 fn build_error_response_with_multiline_message_does_not_panic() {
1388 let resp = build_error_response(
1389 StatusCode::INTERNAL_SERVER_ERROR,
1390 "ServiceException",
1391 "Lambda execution failed: container failed to start: docker start failed: \
1392 Error: unable to start container \"abc\": \
1393 failed to create new hosts file:\nhost-gateway is empty\n",
1394 "req-multi",
1395 AwsProtocol::Json,
1396 );
1397 assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
1398 let msg = resp
1399 .headers()
1400 .get("x-amz-error-message")
1401 .expect("x-amz-error-message must be set even when input contains newlines")
1402 .to_str()
1403 .unwrap();
1404 assert!(!msg.contains('\n'));
1405 assert!(!msg.contains('\r'));
1406 assert!(msg.contains("Lambda execution failed"));
1407 assert!(msg.contains("host-gateway is empty"));
1408 }
1409
1410 #[test]
1411 fn build_error_response_with_control_chars_strips_them() {
1412 let resp = build_error_response(
1413 StatusCode::BAD_REQUEST,
1414 "Code\twith\ttabs",
1415 "msg\x00with\x01nulls",
1416 "req-ctrl",
1417 AwsProtocol::Json,
1418 );
1419 let code = resp
1420 .headers()
1421 .get("x-amz-error-code")
1422 .unwrap()
1423 .to_str()
1424 .unwrap();
1425 let msg = resp
1426 .headers()
1427 .get("x-amz-error-message")
1428 .unwrap()
1429 .to_str()
1430 .unwrap();
1431 assert!(!code.contains('\t'));
1432 assert!(!msg.contains('\x00'));
1433 assert!(!msg.contains('\x01'));
1434 }
1435
1436 #[test]
1437 fn sanitize_header_value_truncates_long_input() {
1438 let huge = "x".repeat(5_000);
1439 let out = sanitize_header_value(&huge);
1440 assert!(out.len() <= 1024);
1441 }
1442
1443 #[test]
1444 fn sanitize_header_value_collapses_consecutive_control_runs() {
1445 let out = sanitize_header_value("a\n\n\n\rb");
1446 assert_eq!(out, "a b");
1447 }
1448
1449 #[test]
1450 fn dispatch_config_carries_opt_in_flags() {
1451 let cfg = DispatchConfig {
1452 region: "eu-west-1".to_string(),
1453 account_id: "000000000000".to_string(),
1454 verify_sigv4: true,
1455 iam_mode: IamMode::Strict,
1456 credential_resolver: None,
1457 policy_evaluator: None,
1458 resource_policy_provider: None,
1459 scp_resolver: None,
1460 };
1461 assert!(cfg.verify_sigv4);
1462 assert!(cfg.iam_mode.is_strict());
1463 assert!(cfg.resource_policy_provider.is_none());
1464 assert!(cfg.scp_resolver.is_none());
1465 }
1466
1467 fn s3_sigv4_headers() -> http::HeaderMap {
1468 let mut headers = http::HeaderMap::new();
1469 headers.insert(
1470 "authorization",
1471 "AWS4-HMAC-SHA256 Credential=test/20240101/us-east-1/s3/aws4_request, \
1472 SignedHeaders=host, Signature=fake"
1473 .parse()
1474 .unwrap(),
1475 );
1476 headers
1477 }
1478
1479 #[test]
1480 fn streaming_route_path_style_s3_put_object() {
1481 let headers = s3_sigv4_headers();
1482 assert_eq!(
1483 streaming_route(
1484 &http::Method::PUT,
1485 "/my-bucket/key.txt",
1486 &headers,
1487 &HashMap::new(),
1488 ),
1489 Some(("s3", "")),
1490 );
1491 }
1492
1493 #[test]
1494 fn streaming_route_path_style_create_bucket_skipped() {
1495 let headers = s3_sigv4_headers();
1498 assert_eq!(
1499 streaming_route(&http::Method::PUT, "/my-bucket", &headers, &HashMap::new(),),
1500 None,
1501 );
1502 }
1503
1504 #[test]
1505 fn streaming_route_virtual_hosted_s3_put_object() {
1506 let mut headers = s3_sigv4_headers();
1507 headers.insert(
1508 "host",
1509 "vhost-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
1510 .parse()
1511 .unwrap(),
1512 );
1513 assert_eq!(
1518 streaming_route(&http::Method::PUT, "/hello.txt", &headers, &HashMap::new(),),
1519 Some(("s3", "")),
1520 );
1521 }
1522
1523 #[test]
1524 fn streaming_route_virtual_hosted_s3_root_skipped() {
1525 let mut headers = s3_sigv4_headers();
1528 headers.insert(
1529 "host",
1530 "vhost-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
1531 .parse()
1532 .unwrap(),
1533 );
1534 assert_eq!(
1535 streaming_route(&http::Method::PUT, "/", &headers, &HashMap::new()),
1536 None,
1537 );
1538 }
1539
1540 #[test]
1541 fn streaming_route_ecr_blob_upload() {
1542 let headers = http::HeaderMap::new();
1543 assert_eq!(
1544 streaming_route(
1545 &http::Method::PATCH,
1546 "/v2/my-repo/blobs/uploads/abcd1234",
1547 &headers,
1548 &HashMap::new(),
1549 ),
1550 Some(("ecr", "")),
1551 );
1552 assert_eq!(
1553 streaming_route(
1554 &http::Method::PUT,
1555 "/v2/my-repo/blobs/uploads/abcd1234",
1556 &headers,
1557 &HashMap::new(),
1558 ),
1559 Some(("ecr", "")),
1560 );
1561 }
1562
1563 #[test]
1564 fn streaming_route_presigned_v4_s3_put() {
1565 let headers = http::HeaderMap::new();
1566 let mut query_params = HashMap::new();
1567 query_params.insert(
1568 "X-Amz-Credential".to_string(),
1569 "test/20240101/us-east-1/s3/aws4_request".to_string(),
1570 );
1571 assert_eq!(
1572 streaming_route(
1573 &http::Method::PUT,
1574 "/my-bucket/key.txt",
1575 &headers,
1576 &query_params,
1577 ),
1578 Some(("s3", "")),
1579 );
1580 }
1581
1582 #[test]
1583 fn streaming_route_non_s3_auth_header_skipped() {
1584 let mut headers = http::HeaderMap::new();
1587 headers.insert(
1588 "authorization",
1589 "AWS4-HMAC-SHA256 Credential=test/20240101/us-east-1/lambda/aws4_request, \
1590 SignedHeaders=host, Signature=fake"
1591 .parse()
1592 .unwrap(),
1593 );
1594 assert_eq!(
1595 streaming_route(
1596 &http::Method::PUT,
1597 "/my-bucket/key.txt",
1598 &headers,
1599 &HashMap::new(),
1600 ),
1601 None,
1602 );
1603 }
1604
1605 #[test]
1606 fn streaming_route_get_skipped() {
1607 let headers = s3_sigv4_headers();
1608 assert_eq!(
1609 streaming_route(
1610 &http::Method::GET,
1611 "/my-bucket/key.txt",
1612 &headers,
1613 &HashMap::new(),
1614 ),
1615 None,
1616 );
1617 }
1618}