1use axum::body::Body;
2use axum::extract::{ConnectInfo, Extension, Query};
3use axum::http::{Request, StatusCode};
4use axum::response::Response;
5use bytes::Bytes;
6use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9
10use crate::auth::{
11 is_root_bypass, ConditionContext, CredentialResolver, IamMode, IamPolicyEvaluator, Principal,
12 PrincipalType, ResourcePolicyProvider,
13};
14use crate::protocol::{self, AwsProtocol};
15use crate::registry::ServiceRegistry;
16use crate::service::{AwsRequest, ResponseBody};
17
/// Single axum entry point that routes every inbound AWS-style request.
///
/// Pipeline, in order:
/// 1. Decide whether the body is streamed (see `streaming_route`) or buffered
///    up to `max_request_body_bytes()`.
/// 2. Detect the target service/action/protocol, falling back to S3 for
///    `OPTIONS`, ECR for `/v2` paths and API Gateway for any path that does
///    not start with `/_`.
/// 3. Look the service up in the registry.
/// 4. Extract SigV4 info (Authorization header, else presigned query), derive
///    the region and resolve the caller's credentials.
/// 5. Optionally verify the SigV4 signature (`config.verify_sigv4`).
/// 6. Rewrite virtual-hosted S3 paths to path style, reject malformed JSON
///    bodies and merge Query-protocol body parameters.
/// 7. Optionally evaluate IAM policies (`config.iam_mode`); denials are always
///    logged but only rejected in strict mode.
/// 8. Invoke the service and render its response, or its error in the wire
///    format of the detected protocol.
///
/// Every response carries the generated request id in both the
/// `x-amzn-requestid` and `x-amz-request-id` headers.
pub async fn dispatch(
    ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
    Extension(registry): Extension<Arc<ServiceRegistry>>,
    Extension(config): Extension<Arc<DispatchConfig>>,
    Query(query_params): Query<HashMap<String, String>>,
    request: Request<Body>,
) -> Response<Body> {
    // `build_condition_context` takes an optional address.
    let remote_addr = Some(remote_addr);
    let request_id = uuid::Uuid::new_v4().to_string();

    let (parts, body) = request.into_parts();

    // Decide, before touching the body, whether this request should stream.
    let stream_route = streaming_route(
        &parts.method,
        parts.uri.path(),
        &parts.headers,
        &query_params,
    );
    let header_only = protocol::detect_service_headers_only(&parts.headers, &query_params);
    // Stream only when header-based detection agrees with the streaming route,
    // or for an ECR route when header-only detection found nothing.
    let stream_dispatch = match (&stream_route, &header_only) {
        (Some(sr), Some(detected)) if sr.0 == detected.service => Some(detected.clone()),
        (Some((service, _)), None) if *service == "ecr" => Some(protocol::DetectedRequest {
            service: "ecr".to_string(),
            action: String::new(),
            protocol: AwsProtocol::Rest,
        }),
        _ => None,
    };

    // Streamed requests leave `body_bytes` empty and hand the raw body to the
    // service through `body_stream`; everything else is buffered.
    let (body_bytes, body_stream) = if stream_dispatch.is_some() {
        (Bytes::new(), Some(body))
    } else {
        let max_body_bytes = max_request_body_bytes();
        match axum::body::to_bytes(body, max_body_bytes).await {
            Ok(b) => (b, None),
            Err(_) => {
                // NOTE(review): `to_bytes` also fails on underlying body read
                // errors, which are reported here as RequestEntityTooLarge as
                // well — confirm that is intended.
                return build_error_response(
                    StatusCode::PAYLOAD_TOO_LARGE,
                    "RequestEntityTooLarge",
                    "Request body too large",
                    &request_id,
                    AwsProtocol::Query,
                );
            }
        }
    };

    let detected = if let Some(d) = stream_dispatch {
        d
    } else {
        match protocol::detect_service(&parts.headers, &query_params, &body_bytes) {
            Some(d) => d,
            None => {
                // Fallbacks for requests without recognisable AWS markers.
                if parts.method == http::Method::OPTIONS {
                    // Any unrecognised OPTIONS goes to S3 (presumably CORS
                    // preflights).
                    protocol::DetectedRequest {
                        service: "s3".to_string(),
                        action: String::new(),
                        protocol: AwsProtocol::Rest,
                    }
                } else if parts.uri.path() == "/v2" || parts.uri.path().starts_with("/v2/") {
                    // `/v2` registry-style paths go to ECR.
                    protocol::DetectedRequest {
                        service: "ecr".to_string(),
                        action: String::new(),
                        protocol: AwsProtocol::Rest,
                    }
                } else if !parts.uri.path().starts_with("/_") {
                    // Any other path not starting with `/_` is handed to API
                    // Gateway.
                    protocol::DetectedRequest {
                        service: "apigateway".to_string(),
                        action: String::new(),
                        protocol: AwsProtocol::RestJson,
                    }
                } else {
                    return build_error_response(
                        StatusCode::BAD_REQUEST,
                        "MissingAction",
                        "Could not determine target service or action from request",
                        &request_id,
                        AwsProtocol::Query,
                    );
                }
            }
        }
    };

    let service = match registry.get(&detected.service) {
        Some(s) => s,
        None => {
            return build_error_response(
                detected.protocol.error_status(),
                "UnknownService",
                &format!("Service '{}' is not available", detected.service),
                &request_id,
                detected.protocol,
            );
        }
    };

    // SigV4 info comes from the Authorization header or, if that does not
    // parse, from presigned-URL query parameters.
    let auth_header = parts
        .headers
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("");
    let header_info = fakecloud_aws::sigv4::parse_sigv4(auth_header);
    let presigned_info = if header_info.is_none() {
        fakecloud_aws::sigv4::parse_sigv4_presigned(&query_params).map(|p| p.as_info())
    } else {
        None
    };
    let sigv4_info = header_info.or(presigned_info);
    let access_key_id = sigv4_info.as_ref().map(|info| info.access_key.clone());

    let host_info = protocol::parse_routing_host_from_headers(&parts.headers);

    // Region precedence: SigV4 credential scope, routing Host header,
    // `region/...` User-Agent token, then the configured default.
    let region = sigv4_info
        .map(|info| info.region)
        .or_else(|| host_info.as_ref().map(|h| h.region.clone()))
        .or_else(|| extract_region_from_user_agent(&parts.headers))
        .unwrap_or_else(|| config.region.clone());

    // Resolve the caller unless it is anonymous or a root-bypass key. An
    // unresolved caller gets no principal and the configured account id.
    let caller_akid = access_key_id.as_deref().unwrap_or("");
    let resolved = if !caller_akid.is_empty() && !is_root_bypass(caller_akid) {
        config
            .credential_resolver
            .as_ref()
            .and_then(|r| r.resolve(caller_akid))
    } else {
        None
    };
    let caller_principal = resolved.as_ref().map(|r| r.principal.clone());
    let caller_session_policies = resolved
        .as_ref()
        .map(|r| r.session_policies.clone())
        .unwrap_or_default();

    // Signature verification runs only when enabled, the caller is not a
    // root-bypass key, and a resolver exists to supply secrets.
    if config.verify_sigv4
        && !is_root_bypass(caller_akid)
        && config.credential_resolver.is_some()
    {
        let amz_date = parts
            .headers
            .get("x-amz-date")
            .and_then(|v| v.to_str().ok());
        let parsed = fakecloud_aws::sigv4::parse_sigv4_header(auth_header, amz_date)
            .or_else(|| fakecloud_aws::sigv4::parse_sigv4_presigned(&query_params));
        let parsed = match parsed {
            Some(p) => p,
            None => {
                return build_error_response(
                    StatusCode::FORBIDDEN,
                    "IncompleteSignature",
                    "Request is missing or has a malformed AWS Signature",
                    &request_id,
                    detected.protocol,
                );
            }
        };
        // Without a resolved key there is no secret to verify against.
        let resolved_for_verify = match resolved.as_ref() {
            Some(r) => r,
            None => {
                return build_error_response(
                    StatusCode::FORBIDDEN,
                    "InvalidClientTokenId",
                    "The security token included in the request is invalid",
                    &request_id,
                    detected.protocol,
                );
            }
        };
        let headers_vec = fakecloud_aws::sigv4::headers_from_http(&parts.headers);
        let raw_query_for_verify = parts.uri.query().unwrap_or("").to_string();
        // NOTE(review): for streamed requests `body_bytes` is empty here; this
        // presumably relies on `verify` honouring a declared payload hash
        // (e.g. `x-amz-content-sha256`) — confirm.
        let verify_req = fakecloud_aws::sigv4::VerifyRequest {
            method: parts.method.as_str(),
            path: parts.uri.path(),
            query: &raw_query_for_verify,
            headers: &headers_vec,
            body: &body_bytes,
        };
        match fakecloud_aws::sigv4::verify(
            &parsed,
            &verify_req,
            &resolved_for_verify.secret_access_key,
            chrono::Utc::now(),
        ) {
            Ok(()) => {}
            Err(fakecloud_aws::sigv4::SigV4Error::RequestTimeTooSkewed { .. }) => {
                return build_error_response(
                    StatusCode::FORBIDDEN,
                    "RequestTimeTooSkewed",
                    "The difference between the request time and the current time is too large",
                    &request_id,
                    detected.protocol,
                );
            }
            Err(fakecloud_aws::sigv4::SigV4Error::InvalidDate(msg)) => {
                return build_error_response(
                    StatusCode::FORBIDDEN,
                    "IncompleteSignature",
                    &format!("Invalid x-amz-date: {msg}"),
                    &request_id,
                    detected.protocol,
                );
            }
            Err(fakecloud_aws::sigv4::SigV4Error::Malformed(msg)) => {
                return build_error_response(
                    StatusCode::FORBIDDEN,
                    "IncompleteSignature",
                    &format!("Malformed SigV4 signature: {msg}"),
                    &request_id,
                    detected.protocol,
                );
            }
            Err(fakecloud_aws::sigv4::SigV4Error::SignatureMismatch) => {
                return build_error_response(
                    StatusCode::FORBIDDEN,
                    "SignatureDoesNotMatch",
                    "The request signature we calculated does not match the signature you provided",
                    &request_id,
                    detected.protocol,
                );
            }
        }
    }

    // Virtual-hosted S3 requests carry the bucket in the Host header; rewrite
    // the path into path-style form (`/{bucket}/...`) unless it already looks
    // like one.
    // NOTE(review): a key that itself starts with `{bucket}/` (or equals the
    // bucket name) is indistinguishable from an already path-style request
    // here and will be misrouted — confirm whether this heuristic is needed.
    let wire_path = parts.uri.path();
    let path = if detected.service == "s3" {
        if let Some(bucket) = host_info.as_ref().and_then(|h| h.bucket.as_deref()) {
            let prefix_with_slash = format!("/{bucket}/");
            let is_bucket_root = wire_path.trim_end_matches('/') == format!("/{bucket}");
            if wire_path.starts_with(&prefix_with_slash) || is_bucket_root {
                wire_path.to_string()
            } else if wire_path == "/" || wire_path.is_empty() {
                format!("/{bucket}")
            } else {
                format!("/{bucket}{wire_path}")
            }
        } else {
            wire_path.to_string()
        }
    } else {
        wire_path.to_string()
    };
    let raw_query = parts.uri.query().unwrap_or("").to_string();
    let path_segments: Vec<String> = path
        .split('/')
        .filter(|s| !s.is_empty())
        .map(|s| s.to_string())
        .collect();

    // Reject non-empty JSON-protocol bodies that do not parse as JSON before
    // the service sees them.
    if detected.protocol == AwsProtocol::Json
        && !body_bytes.is_empty()
        && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
    {
        return build_error_response(
            StatusCode::BAD_REQUEST,
            "SerializationException",
            "Start of structure or map found where not expected",
            &request_id,
            AwsProtocol::Json,
        );
    }

    // For the Query protocol, merge parameters parsed from the body; URL query
    // parameters win on key conflicts (`or_insert` keeps existing entries).
    let mut all_params = query_params;
    if detected.protocol == AwsProtocol::Query {
        let body_params = protocol::parse_query_body(&body_bytes);
        for (k, v) in body_params {
            all_params.entry(k).or_insert(v);
        }
    }

    let aws_request = AwsRequest {
        service: detected.service.clone(),
        action: detected.action.clone(),
        region,
        account_id: caller_principal
            .as_ref()
            .map(|p| p.account_id.clone())
            .unwrap_or_else(|| config.account_id.clone()),
        request_id: request_id.clone(),
        headers: parts.headers,
        query_params: all_params,
        body: body_bytes,
        body_stream: parking_lot::Mutex::new(body_stream),
        path_segments,
        raw_path: path,
        raw_query,
        method: parts.method,
        is_query_protocol: detected.protocol == AwsProtocol::Query,
        access_key_id,
        principal: caller_principal,
    };

    tracing::info!(
        service = %aws_request.service,
        action = %aws_request.action,
        request_id = %aws_request.request_id,
        "handling request"
    );

    // IAM evaluation happens only when it is enabled, the service opts in, and
    // the caller is neither a root-bypass key nor a root principal. Requests
    // with no resolved principal, or with no policy evaluator configured, are
    // not evaluated.
    if config.iam_mode.is_enabled()
        && service.iam_enforceable()
        && !is_root_bypass(aws_request.access_key_id.as_deref().unwrap_or(""))
    {
        if let Some(evaluator) = config.policy_evaluator.as_ref() {
            if let Some(principal) = aws_request.principal.as_ref() {
                if !principal.is_root() {
                    if let Some(iam_action) = service.iam_action_for(&aws_request) {
                        let mut condition_context = build_condition_context(
                            principal,
                            remote_addr,
                            &aws_request.region,
                            is_secure_transport(&aws_request.headers),
                        );
                        condition_context.service_keys =
                            service.iam_condition_keys_for(&aws_request, &iam_action);

                        // ABAC inputs: resource tags, request tags and
                        // principal tags, each only when available.
                        match service.resource_tags_for(&iam_action.resource) {
                            Some(tags) => condition_context.resource_tags = Some(tags),
                            None => tracing::debug!(
                                target: "fakecloud::iam::audit",
                                service = %detected.service,
                                resource = %iam_action.resource,
                                "service does not expose resource tags for ABAC; skipping aws:ResourceTag/* evaluation"
                            ),
                        }
                        match service.request_tags_from(&aws_request, iam_action.action) {
                            Some(tags) => condition_context.request_tags = Some(tags),
                            None => tracing::debug!(
                                target: "fakecloud::iam::audit",
                                service = %detected.service,
                                action = %iam_action.action_string(),
                                "service does not expose request tags for ABAC; skipping aws:RequestTag/* / aws:TagKeys evaluation"
                            ),
                        }
                        condition_context.principal_tags = principal.tags.clone();

                        // Resource-based policy JSON, if a provider knows one.
                        let resource_policy_json =
                            config.resource_policy_provider.as_ref().and_then(|p| {
                                p.resource_policy(&detected.service, &iam_action.resource)
                            });
                        // ARNs without an account field (e.g. S3) are
                        // attributed to the caller's own account.
                        let resource_account_id = parse_account_from_arn(&iam_action.resource)
                            .unwrap_or_else(|| principal.account_id.clone());
                        let scps = config
                            .scp_resolver
                            .as_ref()
                            .and_then(|r| r.scps_for(principal));
                        let decision = evaluator.evaluate_with_resource_policy(
                            principal,
                            &iam_action,
                            &condition_context,
                            resource_policy_json.as_deref(),
                            &resource_account_id,
                            &caller_session_policies,
                            scps.as_deref(),
                        );
                        if !decision.is_allow() {
                            tracing::warn!(
                                target: "fakecloud::iam::audit",
                                service = %detected.service,
                                action = %iam_action.action_string(),
                                resource = %iam_action.resource,
                                principal = %principal.arn,
                                resource_policy_present = resource_policy_json.is_some(),
                                decision = ?decision,
                                mode = %config.iam_mode,
                                request_id = %request_id,
                                "IAM policy evaluation denied request"
                            );
                            // Only strict mode rejects; otherwise the denial is
                            // audit-logged and the request proceeds.
                            if config.iam_mode.is_strict() {
                                return build_error_response(
                                    StatusCode::FORBIDDEN,
                                    "AccessDeniedException",
                                    &format!(
                                        "User: {} is not authorized to perform: {} on resource: {}",
                                        principal.arn,
                                        iam_action.action_string(),
                                        iam_action.resource
                                    ),
                                    &request_id,
                                    detected.protocol,
                                );
                            }
                        }
                    } else {
                        tracing::warn!(
                            target: "fakecloud::iam::audit",
                            service = %detected.service,
                            action = %aws_request.action,
                            "service is iam_enforceable but has no IamAction mapping for this action; skipping evaluation"
                        );
                    }
                }
            }
        }
    }

    match service.handle(aws_request).await {
        Ok(resp) => {
            let mut builder = Response::builder()
                .status(resp.status)
                .header("x-amzn-requestid", &request_id)
                .header("x-amz-request-id", &request_id);

            if !resp.content_type.is_empty() {
                builder = builder.header("content-type", &resp.content_type);
            }

            // A service-supplied Content-Length takes precedence over the one
            // derived from a file body's size below.
            let has_content_length = resp
                .headers
                .iter()
                .any(|(k, _)| k.as_str().eq_ignore_ascii_case("content-length"));

            // NOTE(review): `http::response::Builder` records an error for an
            // invalid header name/value and `.body()` then returns `Err`, so a
            // service emitting a bad header makes the `unwrap()`s below panic.
            for (k, v) in &resp.headers {
                builder = builder.header(k, v);
            }

            match resp.body {
                ResponseBody::Bytes(b) => builder.body(Body::from(b)).unwrap(),
                ResponseBody::File { file, size } => {
                    // File bodies are streamed rather than read into memory.
                    let stream = tokio_util::io::ReaderStream::new(file);
                    let body = Body::from_stream(stream);
                    if !has_content_length {
                        builder = builder.header("content-length", size.to_string());
                    }
                    builder.body(body).unwrap()
                }
            }
        }
        Err(err) => {
            tracing::warn!(
                service = %detected.service,
                action = %detected.action,
                error = %err,
                "request failed"
            );
            let error_headers = err.response_headers().to_vec();
            let mut resp = build_error_response_with_fields(
                err.status(),
                err.code(),
                &err.message(),
                &request_id,
                detected.protocol,
                err.extra_fields(),
            );
            // Error-specific headers replace same-named defaults; entries that
            // do not parse as a valid header name/value are silently dropped.
            for (k, v) in &error_headers {
                if let (Ok(name), Ok(val)) = (
                    k.parse::<http::header::HeaderName>(),
                    v.parse::<http::header::HeaderValue>(),
                ) {
                    resp.headers_mut().insert(name, val);
                }
            }
            resp
        }
    }
}
558
/// Settings shared by every `dispatch` call, injected as an axum
/// `Extension<Arc<DispatchConfig>>`.
///
/// `DispatchConfig::new` leaves every opt-in feature disabled.
#[derive(Clone)]
pub struct DispatchConfig {
    /// Region used when neither the SigV4 credential scope, the routing Host
    /// header nor the User-Agent names one.
    pub region: String,
    /// Account id assigned to requests whose caller principal was not resolved.
    pub account_id: String,
    /// When `true`, SigV4 signatures are verified. This only applies when a
    /// `credential_resolver` is configured and the caller is not a
    /// root-bypass key.
    pub verify_sigv4: bool,
    /// IAM evaluation switch. When `is_enabled()`, denials are logged; only
    /// when `is_strict()` are denied requests rejected.
    pub iam_mode: IamMode,
    /// Maps an access key id to its principal, secret key and session policies.
    pub credential_resolver: Option<Arc<dyn CredentialResolver>>,
    /// Evaluates identity, resource and SCP policies for IAM-enforced requests.
    pub policy_evaluator: Option<Arc<dyn IamPolicyEvaluator>>,
    /// Looks up resource-based policy JSON for a (service, resource ARN) pair.
    pub resource_policy_provider: Option<Arc<dyn ResourcePolicyProvider>>,
    /// Supplies the service control policies that apply to a principal.
    pub scp_resolver: Option<Arc<dyn crate::auth::ScpResolver>>,
}
596
597impl std::fmt::Debug for DispatchConfig {
598 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
599 f.debug_struct("DispatchConfig")
600 .field("region", &self.region)
601 .field("account_id", &self.account_id)
602 .field("verify_sigv4", &self.verify_sigv4)
603 .field("iam_mode", &self.iam_mode)
604 .field(
605 "credential_resolver",
606 &self
607 .credential_resolver
608 .as_ref()
609 .map(|_| "<CredentialResolver>"),
610 )
611 .field(
612 "policy_evaluator",
613 &self
614 .policy_evaluator
615 .as_ref()
616 .map(|_| "<IamPolicyEvaluator>"),
617 )
618 .field(
619 "resource_policy_provider",
620 &self
621 .resource_policy_provider
622 .as_ref()
623 .map(|_| "<ResourcePolicyProvider>"),
624 )
625 .field(
626 "scp_resolver",
627 &self.scp_resolver.as_ref().map(|_| "<ScpResolver>"),
628 )
629 .finish()
630 }
631}
632
633impl DispatchConfig {
634 pub fn new(region: impl Into<String>, account_id: impl Into<String>) -> Self {
637 Self {
638 region: region.into(),
639 account_id: account_id.into(),
640 verify_sigv4: false,
641 iam_mode: IamMode::Off,
642 credential_resolver: None,
643 policy_evaluator: None,
644 resource_policy_provider: None,
645 scp_resolver: None,
646 }
647 }
648}
649
650fn streaming_route(
670 method: &http::Method,
671 path: &str,
672 headers: &http::HeaderMap,
673 query_params: &HashMap<String, String>,
674) -> Option<(&'static str, &'static str)> {
675 if (method == http::Method::PATCH || method == http::Method::PUT)
677 && path.starts_with("/v2/")
678 && path.contains("/blobs/uploads/")
679 {
680 return Some(("ecr", ""));
681 }
682
683 if method == http::Method::PUT {
688 let after = path.trim_start_matches('/');
689 let virtual_hosted_s3 = protocol::parse_routing_host_from_headers(headers)
695 .filter(|h| h.service == "s3" && h.bucket.is_some())
696 .is_some();
697 if after.is_empty() || (!virtual_hosted_s3 && !after.contains('/')) {
698 return None;
699 }
700 let header_s3 = headers
701 .get("authorization")
702 .and_then(|v| v.to_str().ok())
703 .and_then(fakecloud_aws::sigv4::parse_sigv4)
704 .map(|info| info.service == "s3")
705 .unwrap_or(false);
706 let presigned_v4_s3 = query_params
707 .get("X-Amz-Credential")
708 .and_then(|c| c.split('/').nth(3).map(|s| s.to_string()))
709 .map(|service| service == "s3")
710 .unwrap_or(false);
711 let presigned_v2 = query_params.contains_key("AWSAccessKeyId")
712 && query_params.contains_key("Signature")
713 && query_params.contains_key("Expires");
714 if header_s3 || presigned_v4_s3 || presigned_v2 {
715 return Some(("s3", ""));
716 }
717 }
718
719 None
720}
721
/// Default cap, in bytes, on a buffered (non-streaming) request body: 1 GiB.
const DEFAULT_MAX_REQUEST_BODY_BYTES: usize = 1024 * 1024 * 1024;

/// Returns the maximum number of bytes `dispatch` will buffer for a
/// non-streaming request body.
///
/// Uses `FAKECLOUD_MAX_REQUEST_BODY_BYTES` when it is set to a positive
/// integer. A missing, unparsable or zero value falls back to
/// [`DEFAULT_MAX_REQUEST_BODY_BYTES`]. The result is computed once and cached
/// for the life of the process, so later environment changes are ignored.
fn max_request_body_bytes() -> usize {
    static LIMIT: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *LIMIT.get_or_init(|| match std::env::var("FAKECLOUD_MAX_REQUEST_BODY_BYTES") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(n) if n > 0 => n,
            _ => DEFAULT_MAX_REQUEST_BODY_BYTES,
        },
        Err(_) => DEFAULT_MAX_REQUEST_BODY_BYTES,
    })
}
743
/// Extracts the account-id field from an ARN of the form
/// `arn:partition:service:region:account:resource`.
///
/// Returns `None` in three cases:
/// - the string does not start with `arn`;
/// - it has fewer than six `:`-separated fields (the resource may itself
///   contain `:`);
/// - the account field is empty (e.g. S3 bucket ARNs).
fn parse_account_from_arn(arn: &str) -> Option<String> {
    let fields: Vec<&str> = arn.splitn(6, ':').collect();
    match fields.as_slice() {
        ["arn", _partition, _service, _region, account, _resource] if !account.is_empty() => {
            Some((*account).to_owned())
        }
        _ => None,
    }
}
767
768fn extract_region_from_user_agent(headers: &http::HeaderMap) -> Option<String> {
770 let ua = headers.get("user-agent")?.to_str().ok()?;
771 for part in ua.split_whitespace() {
772 if let Some(region) = part.strip_prefix("region/") {
773 if !region.is_empty() {
774 return Some(region.to_string());
775 }
776 }
777 }
778 None
779}
780
781fn build_error_response(
782 status: StatusCode,
783 code: &str,
784 message: &str,
785 request_id: &str,
786 protocol: AwsProtocol,
787) -> Response<Body> {
788 build_error_response_with_fields(status, code, message, request_id, protocol, &[])
789}
790
791fn build_error_response_with_fields(
792 status: StatusCode,
793 code: &str,
794 message: &str,
795 request_id: &str,
796 protocol: AwsProtocol,
797 extra_fields: &[(String, String)],
798) -> Response<Body> {
799 let (status, content_type, body) = match protocol {
800 AwsProtocol::Query => {
801 fakecloud_aws::error::xml_error_response(status, code, message, request_id)
802 }
803 AwsProtocol::Rest => fakecloud_aws::error::s3_xml_error_response_with_fields(
804 status,
805 code,
806 message,
807 request_id,
808 extra_fields,
809 ),
810 AwsProtocol::Json | AwsProtocol::RestJson => {
811 fakecloud_aws::error::json_error_response(status, code, message)
812 }
813 };
814
815 Response::builder()
816 .status(status)
817 .header("content-type", content_type)
818 .header("x-amzn-requestid", request_id)
819 .header("x-amz-request-id", request_id)
820 .body(Body::from(body))
821 .unwrap()
822}
823
824fn build_condition_context(
829 principal: &Principal,
830 remote_addr: Option<SocketAddr>,
831 region: &str,
832 secure_transport: bool,
833) -> ConditionContext {
834 let now = chrono::Utc::now();
835 ConditionContext {
836 aws_username: aws_username_from_principal(principal),
837 aws_userid: Some(principal.user_id.clone()),
838 aws_principal_arn: Some(principal.arn.clone()),
839 aws_principal_account: Some(principal.account_id.clone()),
840 aws_principal_type: Some(principal_type_label(principal.principal_type).to_string()),
841 aws_source_ip: remote_addr.map(|sa| sa.ip()),
842 aws_current_time: Some(now),
843 aws_epoch_time: Some(now.timestamp()),
844 aws_secure_transport: Some(secure_transport),
845 aws_requested_region: Some(region.to_string()),
846 service_keys: Default::default(),
847 resource_tags: None,
848 request_tags: None,
849 principal_tags: None,
850 }
851}
852
853fn aws_username_from_principal(principal: &Principal) -> Option<String> {
857 if principal.principal_type != PrincipalType::User {
858 return None;
859 }
860 let after = principal.arn.rsplit_once(":user/").map(|(_, s)| s)?;
861 Some(after.rsplit('/').next().unwrap_or(after).to_string())
863}
864
865fn principal_type_label(t: PrincipalType) -> &'static str {
868 match t {
869 PrincipalType::User => "User",
870 PrincipalType::AssumedRole => "AssumedRole",
871 PrincipalType::FederatedUser => "FederatedUser",
872 PrincipalType::Root => "Account",
873 PrincipalType::Unknown => "Unknown",
874 }
875}
876
877fn is_secure_transport(headers: &http::HeaderMap) -> bool {
883 headers
884 .get("x-forwarded-proto")
885 .and_then(|v| v.to_str().ok())
886 .map(|s| s.eq_ignore_ascii_case("https"))
887 .unwrap_or(false)
888}
889
890trait ProtocolExt {
891 fn error_status(&self) -> StatusCode;
892}
893
894impl ProtocolExt for AwsProtocol {
895 fn error_status(&self) -> StatusCode {
896 StatusCode::BAD_REQUEST
897 }
898}
899
// Unit tests for the dispatcher's pure helpers: config defaults, condition
// context construction, ARN/User-Agent parsing, error rendering and streaming
// route detection.
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the documented 1 GiB default body cap.
    #[test]
    fn default_max_request_body_bytes_is_one_gib() {
        assert_eq!(DEFAULT_MAX_REQUEST_BODY_BYTES, 1024 * 1024 * 1024);
    }

    // `DispatchConfig::new` must leave every opt-in feature disabled.
    #[test]
    fn dispatch_config_new_defaults_to_off() {
        let cfg = DispatchConfig::new("us-east-1", "123456789012");
        assert_eq!(cfg.region, "us-east-1");
        assert_eq!(cfg.account_id, "123456789012");
        assert!(!cfg.verify_sigv4);
        assert_eq!(cfg.iam_mode, IamMode::Off);
    }

    // The IAM path (`engineering/`) is not part of `aws:username`.
    #[test]
    fn aws_username_strips_iam_path_for_users() {
        let p = Principal {
            arn: "arn:aws:iam::123456789012:user/engineering/alice".into(),
            user_id: "AIDAALICE".into(),
            account_id: "123456789012".into(),
            principal_type: PrincipalType::User,
            source_identity: None,
            tags: None,
        };
        assert_eq!(aws_username_from_principal(&p), Some("alice".into()));
    }

    #[test]
    fn aws_username_unset_for_assumed_role() {
        let p = Principal {
            arn: "arn:aws:sts::123456789012:assumed-role/ops/session".into(),
            user_id: "AROAOPS:session".into(),
            account_id: "123456789012".into(),
            principal_type: PrincipalType::AssumedRole,
            source_identity: None,
            tags: None,
        };
        assert_eq!(aws_username_from_principal(&p), None);
    }

    // Root principals are labelled "Account", not "Root".
    #[test]
    fn principal_type_label_matches_aws_casing() {
        assert_eq!(principal_type_label(PrincipalType::User), "User");
        assert_eq!(
            principal_type_label(PrincipalType::AssumedRole),
            "AssumedRole"
        );
        assert_eq!(principal_type_label(PrincipalType::Root), "Account");
    }

    #[test]
    fn build_condition_context_populates_global_keys() {
        let p = Principal {
            arn: "arn:aws:iam::123456789012:user/alice".into(),
            user_id: "AIDAALICE".into(),
            account_id: "123456789012".into(),
            principal_type: PrincipalType::User,
            source_identity: None,
            tags: None,
        };
        let addr: SocketAddr = "10.0.0.1:54321".parse().unwrap();
        let ctx = build_condition_context(&p, Some(addr), "us-east-1", false);
        assert_eq!(ctx.aws_username.as_deref(), Some("alice"));
        assert_eq!(ctx.aws_userid.as_deref(), Some("AIDAALICE"));
        assert_eq!(
            ctx.aws_principal_arn.as_deref(),
            Some("arn:aws:iam::123456789012:user/alice")
        );
        assert_eq!(ctx.aws_principal_account.as_deref(), Some("123456789012"));
        assert_eq!(ctx.aws_principal_type.as_deref(), Some("User"));
        // Only the IP is kept; the port is dropped.
        assert_eq!(
            ctx.aws_source_ip.map(|i| i.to_string()).as_deref(),
            Some("10.0.0.1")
        );
        assert_eq!(ctx.aws_requested_region.as_deref(), Some("us-east-1"));
        assert_eq!(ctx.aws_secure_transport, Some(false));
        assert!(ctx.aws_current_time.is_some());
        assert!(ctx.aws_epoch_time.is_some());
    }

    #[test]
    fn is_secure_transport_reads_x_forwarded_proto() {
        let mut headers = http::HeaderMap::new();
        headers.insert("x-forwarded-proto", "https".parse().unwrap());
        assert!(is_secure_transport(&headers));
        headers.insert("x-forwarded-proto", "http".parse().unwrap());
        assert!(!is_secure_transport(&headers));
        let empty = http::HeaderMap::new();
        assert!(!is_secure_transport(&empty));
    }

    #[test]
    fn parse_account_from_arn_extracts_standard_shapes() {
        assert_eq!(
            parse_account_from_arn("arn:aws:sqs:us-east-1:123456789012:queue"),
            Some("123456789012".to_string())
        );
        assert_eq!(
            parse_account_from_arn("arn:aws:iam::123456789012:user/alice"),
            Some("123456789012".to_string())
        );
    }

    // S3 ARNs have an empty account field, so the caller's account is used.
    #[test]
    fn parse_account_from_arn_returns_none_for_s3_empty_account() {
        assert_eq!(parse_account_from_arn("arn:aws:s3:::my-bucket"), None);
        assert_eq!(
            parse_account_from_arn("arn:aws:s3:::my-bucket/path/to/key"),
            None
        );
    }

    #[test]
    fn parse_account_from_arn_returns_none_for_malformed() {
        assert_eq!(parse_account_from_arn(""), None);
        assert_eq!(parse_account_from_arn("not-an-arn"), None);
        assert_eq!(parse_account_from_arn("arn:aws:sqs:us-east-1"), None);
        assert_eq!(parse_account_from_arn("arn:aws:sqs"), None);
    }

    #[test]
    fn extract_region_from_user_agent_finds_region_segment() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            "user-agent",
            "aws-sdk-rust/1.0 os/linux region/eu-central-1"
                .parse()
                .unwrap(),
        );
        assert_eq!(
            extract_region_from_user_agent(&headers),
            Some("eu-central-1".to_string())
        );
    }

    #[test]
    fn extract_region_from_user_agent_none_without_header() {
        let headers = http::HeaderMap::new();
        assert_eq!(extract_region_from_user_agent(&headers), None);
    }

    #[test]
    fn extract_region_from_user_agent_ignores_empty_region() {
        let mut headers = http::HeaderMap::new();
        headers.insert("user-agent", "aws-sdk-java region/".parse().unwrap());
        assert_eq!(extract_region_from_user_agent(&headers), None);
    }

    #[test]
    fn extract_region_from_user_agent_none_when_no_region_marker() {
        let mut headers = http::HeaderMap::new();
        headers.insert("user-agent", "curl/7.79.1".parse().unwrap());
        assert_eq!(extract_region_from_user_agent(&headers), None);
    }

    #[test]
    fn aws_username_none_for_root() {
        let p = Principal {
            arn: "arn:aws:iam::123456789012:root".into(),
            user_id: "123456789012".into(),
            account_id: "123456789012".into(),
            principal_type: PrincipalType::Root,
            source_identity: None,
            tags: None,
        };
        assert_eq!(aws_username_from_principal(&p), None);
    }

    #[test]
    fn aws_username_bare_no_path() {
        let p = Principal {
            arn: "arn:aws:iam::123456789012:user/bob".into(),
            user_id: "AIDABOB".into(),
            account_id: "123456789012".into(),
            principal_type: PrincipalType::User,
            source_identity: None,
            tags: None,
        };
        assert_eq!(aws_username_from_principal(&p), Some("bob".into()));
    }

    #[test]
    fn principal_type_label_covers_federated_and_unknown() {
        assert_eq!(
            principal_type_label(PrincipalType::FederatedUser),
            "FederatedUser"
        );
        assert_eq!(principal_type_label(PrincipalType::Unknown), "Unknown");
    }

    #[test]
    fn build_condition_context_marks_secure_when_flag_set() {
        let p = Principal {
            arn: "arn:aws:iam::123456789012:user/alice".into(),
            user_id: "AIDAALICE".into(),
            account_id: "123456789012".into(),
            principal_type: PrincipalType::User,
            source_identity: None,
            tags: None,
        };
        let ctx = build_condition_context(&p, None, "us-west-2", true);
        assert_eq!(ctx.aws_secure_transport, Some(true));
        assert!(ctx.aws_source_ip.is_none());
        assert_eq!(ctx.aws_requested_region.as_deref(), Some("us-west-2"));
    }

    #[test]
    fn is_secure_transport_case_insensitive() {
        let mut headers = http::HeaderMap::new();
        headers.insert("x-forwarded-proto", "HTTPS".parse().unwrap());
        assert!(is_secure_transport(&headers));
    }

    // Header values that are not visible ASCII fail `to_str` and count as
    // insecure.
    #[test]
    fn is_secure_transport_non_ascii_bytes_false() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            "x-forwarded-proto",
            http::HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap(),
        );
        assert!(!is_secure_transport(&headers));
    }

    #[test]
    fn protocol_ext_error_status_is_bad_request() {
        assert_eq!(AwsProtocol::Query.error_status(), StatusCode::BAD_REQUEST);
        assert_eq!(AwsProtocol::Json.error_status(), StatusCode::BAD_REQUEST);
        assert_eq!(AwsProtocol::Rest.error_status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            AwsProtocol::RestJson.error_status(),
            StatusCode::BAD_REQUEST
        );
    }

    #[test]
    fn build_error_response_json_has_json_content_type() {
        let resp = build_error_response(
            StatusCode::BAD_REQUEST,
            "TestCode",
            "test msg",
            "req-1",
            AwsProtocol::Json,
        );
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let ct = resp
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(ct.contains("json"));
        let rid = resp
            .headers()
            .get("x-amzn-requestid")
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(rid, "req-1");
    }

    #[test]
    fn build_error_response_rest_returns_xml_content_type() {
        let resp = build_error_response(
            StatusCode::NOT_FOUND,
            "NoSuchBucket",
            "bucket missing",
            "req-2",
            AwsProtocol::Rest,
        );
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        let ct = resp
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(ct.contains("xml"));
    }

    #[test]
    fn build_error_response_query_returns_xml() {
        let resp = build_error_response(
            StatusCode::BAD_REQUEST,
            "InvalidParameter",
            "bad param",
            "req-3",
            AwsProtocol::Query,
        );
        let ct = resp
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(ct.contains("xml"));
    }

    #[test]
    fn dispatch_config_carries_opt_in_flags() {
        let cfg = DispatchConfig {
            region: "eu-west-1".to_string(),
            account_id: "000000000000".to_string(),
            verify_sigv4: true,
            iam_mode: IamMode::Strict,
            credential_resolver: None,
            policy_evaluator: None,
            resource_policy_provider: None,
            scp_resolver: None,
        };
        assert!(cfg.verify_sigv4);
        assert!(cfg.iam_mode.is_strict());
        assert!(cfg.resource_policy_provider.is_none());
        assert!(cfg.scp_resolver.is_none());
    }

    // Authorization header whose SigV4 credential scope names the `s3`
    // service. The signature is fake; `streaming_route` only parses the scope.
    fn s3_sigv4_headers() -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 Credential=test/20240101/us-east-1/s3/aws4_request, \
             SignedHeaders=host, Signature=fake"
                .parse()
                .unwrap(),
        );
        headers
    }

    #[test]
    fn streaming_route_path_style_s3_put_object() {
        let headers = s3_sigv4_headers();
        assert_eq!(
            streaming_route(
                &http::Method::PUT,
                "/my-bucket/key.txt",
                &headers,
                &HashMap::new(),
            ),
            Some(("s3", "")),
        );
    }

    // Path-style `PUT /bucket` has no key, so it is buffered, not streamed.
    #[test]
    fn streaming_route_path_style_create_bucket_skipped() {
        let headers = s3_sigv4_headers();
        assert_eq!(
            streaming_route(&http::Method::PUT, "/my-bucket", &headers, &HashMap::new(),),
            None,
        );
    }

    // With the bucket in the Host header, a single-segment path is an object key.
    #[test]
    fn streaming_route_virtual_hosted_s3_put_object() {
        let mut headers = s3_sigv4_headers();
        headers.insert(
            "host",
            "vhost-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        assert_eq!(
            streaming_route(&http::Method::PUT, "/hello.txt", &headers, &HashMap::new(),),
            Some(("s3", "")),
        );
    }

    // A virtual-hosted `PUT /` targets the bucket itself, so it is not streamed.
    #[test]
    fn streaming_route_virtual_hosted_s3_root_skipped() {
        let mut headers = s3_sigv4_headers();
        headers.insert(
            "host",
            "vhost-bucket.s3.us-east-1.localhost.localstack.cloud:4566"
                .parse()
                .unwrap(),
        );
        assert_eq!(
            streaming_route(&http::Method::PUT, "/", &headers, &HashMap::new()),
            None,
        );
    }

    // ECR blob uploads stream for both PATCH (chunk) and PUT (finalise),
    // even without any AWS auth headers.
    #[test]
    fn streaming_route_ecr_blob_upload() {
        let headers = http::HeaderMap::new();
        assert_eq!(
            streaming_route(
                &http::Method::PATCH,
                "/v2/my-repo/blobs/uploads/abcd1234",
                &headers,
                &HashMap::new(),
            ),
            Some(("ecr", "")),
        );
        assert_eq!(
            streaming_route(
                &http::Method::PUT,
                "/v2/my-repo/blobs/uploads/abcd1234",
                &headers,
                &HashMap::new(),
            ),
            Some(("ecr", "")),
        );
    }

    #[test]
    fn streaming_route_presigned_v4_s3_put() {
        let headers = http::HeaderMap::new();
        let mut query_params = HashMap::new();
        query_params.insert(
            "X-Amz-Credential".to_string(),
            "test/20240101/us-east-1/s3/aws4_request".to_string(),
        );
        assert_eq!(
            streaming_route(
                &http::Method::PUT,
                "/my-bucket/key.txt",
                &headers,
                &query_params,
            ),
            Some(("s3", "")),
        );
    }

    // A path that looks like `bucket/key` is not enough; the signature scope
    // must name S3 (here it names `lambda`).
    #[test]
    fn streaming_route_non_s3_auth_header_skipped() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            "authorization",
            "AWS4-HMAC-SHA256 Credential=test/20240101/us-east-1/lambda/aws4_request, \
             SignedHeaders=host, Signature=fake"
                .parse()
                .unwrap(),
        );
        assert_eq!(
            streaming_route(
                &http::Method::PUT,
                "/my-bucket/key.txt",
                &headers,
                &HashMap::new(),
            ),
            None,
        );
    }

    #[test]
    fn streaming_route_get_skipped() {
        let headers = s3_sigv4_headers();
        assert_eq!(
            streaming_route(
                &http::Method::GET,
                "/my-bucket/key.txt",
                &headers,
                &HashMap::new(),
            ),
            None,
        );
    }
}