use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use axum::body::Body;
use axum::extract::State;
use axum::http::{HeaderMap, Method, Response, StatusCode, Uri};
use bytes::Bytes;
use tracing::{debug, info, warn};
use crate::ServiceHandler;
use crate::auth;
use crate::authz::AuthzEngine;
use crate::body_store::BodyStore;
use crate::error::AwsError;
use crate::events::EventBus;
use crate::protocol::{self, Protocol, RouteDefinition};
use crate::request_detail::{
CapturedBody, RequestDetail, RequestDetailStore, capture_body, capture_headers,
};
use crate::request_event::{RequestEvent, RequestEventBus};
/// Cloneable handle bundling a shared [`BodyStore`] with a service name and a
/// list of group names.
///
/// NOTE(review): nothing in this file reads these fields; the exact meaning of
/// `groups` is defined by whoever consumes `AppState::body_stores` — confirm there.
#[derive(Clone)]
pub struct BodyStoreHandle {
/// Service name associated with the store.
pub service_name: String,
/// Group names associated with the store (semantics defined by the consumer).
pub groups: Vec<String>,
/// The shared body store itself.
pub body_store: Arc<BodyStore>,
}
/// Shared gateway state handed to every request handler.
///
/// The struct is `Clone`; the maps and engines are behind `Arc`, so clones
/// share them.
#[derive(Clone)]
pub struct AppState {
/// Registered service handlers, keyed by the handler's signing name.
pub services: Arc<HashMap<String, Arc<dyn ServiceHandler>>>,
/// Route tables keyed by signing name; only present for services that
/// registered at least one route.
pub routes: Arc<HashMap<String, Vec<RouteDefinition>>>,
/// Region used when a request carries no credential scope.
pub default_region: String,
/// Account id attributed to every request handled by this file.
pub default_account_id: String,
/// Partition copied into each `RequestContext`.
pub default_partition: String,
/// Bus that receives API-call details and is passed to handlers through the
/// request context.
pub event_bus: EventBus,
/// Number of requests dispatched so far (incremented once per dispatch).
pub request_count: Arc<AtomicU64>,
/// Instant at which this state was constructed.
pub start_time: std::time::Instant,
/// Authorization engine; by default built from the environment.
pub authz: Arc<AuthzEngine>,
/// Body-store handles. Not read in this file.
pub body_stores: Arc<Vec<BodyStoreHandle>>,
/// Optional data directory. Not read in this file; defaults to `None`.
pub data_dir: Option<Arc<std::path::PathBuf>>,
/// Bus that receives one `RequestEvent` per dispatched request.
pub events: RequestEventBus,
/// Store of captured per-request details (headers and bodies).
pub request_details: RequestDetailStore,
/// Chaos engine consulted before each operation is handed to its handler.
pub chaos: Arc<awsim_chaos::ChaosEngine>,
/// Worker pool; the tick loop's log message points slow tick work here.
pub workers: crate::tick::WorkerPool,
}
impl AppState {
    /// Creates a state with no registered services, using
    /// `crate::router::DEFAULT_PARTITION` as the partition.
    pub fn new(default_region: String, default_account_id: String) -> Self {
        Self::with_partition(
            default_region,
            default_account_id,
            crate::router::DEFAULT_PARTITION.to_string(),
        )
    }

    /// Creates a state with no registered services and an explicit partition.
    ///
    /// The authorization engine is built from the environment, the request
    /// counter starts at zero and `start_time` is set to "now".
    pub fn with_partition(
        default_region: String,
        default_account_id: String,
        default_partition: String,
    ) -> Self {
        Self {
            services: Arc::new(HashMap::new()),
            routes: Arc::new(HashMap::new()),
            default_region,
            default_account_id,
            default_partition,
            event_bus: EventBus::new(),
            request_count: Arc::new(AtomicU64::new(0)),
            start_time: std::time::Instant::now(),
            authz: Arc::new(AuthzEngine::from_env()),
            body_stores: Arc::new(Vec::new()),
            data_dir: None,
            events: RequestEventBus::new(),
            request_details: RequestDetailStore::default(),
            chaos: Arc::new(awsim_chaos::ChaosEngine::new()),
            workers: crate::tick::WorkerPool::new(),
        }
    }

    /// Registers `handler` (and its `routes`, if any) under the handler's
    /// signing name.
    ///
    /// # Panics
    ///
    /// Panics if the service or route map is already shared, i.e. if this
    /// `AppState` has been cloned before registration finished. All services
    /// must be registered before the state is handed out.
    pub fn register(&mut self, handler: Arc<dyn ServiceHandler>, routes: Vec<RouteDefinition>) {
        let signing_name = handler.signing_name().to_string();
        let service_name = handler.service_name().to_string();
        info!(
            service = %service_name,
            signing_name = %signing_name,
            protocol = ?handler.protocol(),
            routes = routes.len(),
            "Registered service"
        );
        Arc::get_mut(&mut self.services)
            .expect("AppState::register must be called before the state is cloned (service map is shared)")
            .insert(signing_name.clone(), handler);
        // Services without routes get no entry; lookups fall back to an empty table.
        if !routes.is_empty() {
            Arc::get_mut(&mut self.routes)
                .expect("AppState::register must be called before the state is cloned (route map is shared)")
                .insert(signing_name, routes);
        }
    }
}
/// Successful result of `process_request`: everything needed to build the
/// HTTP response plus the name of the operation that was executed.
struct ProcessOk {
/// HTTP status of the response.
status: StatusCode,
/// Response headers produced by the protocol serializer or the streaming path.
headers: HeaderMap,
/// Response payload, either buffered or streamed.
body: ProcessBody,
/// Name of the parsed operation.
operation: String,
}
/// Response payload in one of two shapes.
enum ProcessBody {
/// Fully buffered payload.
Bytes(Bytes),
/// Payload produced incrementally by the handler.
Stream(crate::HandlerByteStream),
}
impl ProcessBody {
fn buffered_len(&self) -> Option<usize> {
match self {
ProcessBody::Bytes(b) => Some(b.len()),
ProcessBody::Stream(_) => None,
}
}
}
/// Request attribution collected while processing, filled in even when the
/// request ultimately fails so that events can still be emitted.
struct ProcessMeta {
/// Resolved service name; stays empty until the request has been parsed.
service: String,
/// Region from the credential scope, or the default region.
region: String,
/// Account id attributed to the request.
account_id: String,
/// Access key id extracted from the request, if any.
access_key: Option<String>,
}
/// Reports whether unsigned requests must be rejected.
///
/// The answer comes from the `AWSIM_REQUIRE_SIGNED_REQUESTS` environment
/// variable (`1`, `true`, `TRUE`, `yes` or `YES` enable it; anything else,
/// including an unset variable, disables it). The variable is read once and
/// the result is cached for the lifetime of the process.
fn require_signed_requests_enabled() -> bool {
    static REQUIRE_SIGNED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

    /// Reads and interprets the environment variable.
    fn read_flag() -> bool {
        match std::env::var("AWSIM_REQUIRE_SIGNED_REQUESTS") {
            Ok(raw) => ["1", "true", "TRUE", "yes", "YES"].contains(&raw.as_str()),
            Err(_) => false,
        }
    }

    *REQUIRE_SIGNED.get_or_init(read_flag)
}
pub fn spawn_tick_loop(
state: AppState,
interval: std::time::Duration,
) -> tokio::task::JoinHandle<()> {
use futures::FutureExt;
use std::panic::AssertUnwindSafe;
const PER_SERVICE_TICK_DEADLINE: std::time::Duration = std::time::Duration::from_millis(50);
tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
ticker.tick().await; loop {
ticker.tick().await;
let services: Vec<(String, Arc<dyn ServiceHandler>)> = state
.services
.iter()
.map(|(name, svc)| (name.clone(), svc.clone()))
.collect();
for (name, svc) in services {
let tick_fut = AssertUnwindSafe(svc.tick()).catch_unwind();
match tokio::time::timeout(PER_SERVICE_TICK_DEADLINE, tick_fut).await {
Ok(Ok(())) => {}
Ok(Err(panic)) => {
let msg = panic
.downcast_ref::<String>()
.cloned()
.or_else(|| panic.downcast_ref::<&'static str>().map(|s| s.to_string()))
.unwrap_or_else(|| "<non-string panic payload>".to_string());
warn!(service = %name, panic = %msg, "service tick panicked");
}
Err(_) => {
warn!(
service = %name,
budget_ms = PER_SERVICE_TICK_DEADLINE.as_millis() as u64,
"service tick exceeded budget; consider moving slow work to AppState::workers"
);
}
}
}
}
})
}
/// Axum entry point for every incoming request.
///
/// Well-known browser probes (favicon and friends) are answered directly with
/// `204 No Content`; everything else goes through [`dispatch_request`] and the
/// generated request id is discarded.
pub async fn handle_request(
    State(state): State<AppState>,
    method: Method,
    uri: Uri,
    headers: HeaderMap,
    body: Bytes,
) -> Response<Body> {
    if !is_browser_probe(&method, uri.path()) {
        return dispatch_request(&state, method, uri, headers, body).await.0;
    }
    Response::builder()
        .status(StatusCode::NO_CONTENT)
        .body(Body::empty())
        .unwrap()
}
/// Returns `true` for `GET` requests to one of the fixed paths that browsers
/// fetch on their own (icons, `robots.txt`, the Chrome DevTools probe).
fn is_browser_probe(method: &Method, path: &str) -> bool {
    const PROBE_PATHS: [&str; 5] = [
        "/favicon.ico",
        "/apple-touch-icon.png",
        "/apple-touch-icon-precomposed.png",
        "/robots.txt",
        "/.well-known/appspecific/com.chrome.devtools.json",
    ];
    *method == Method::GET && PROBE_PATHS.contains(&path)
}
pub async fn dispatch_request(
state: &AppState,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Bytes,
) -> (Response<Body>, String) {
state.request_count.fetch_add(1, Ordering::Relaxed);
let request_id = uuid::Uuid::new_v4().to_string();
let started = Instant::now();
let request_size = body.len() as u64;
debug!(
method = %method,
uri = %uri,
request_id = %request_id,
"Incoming request"
);
let mut meta = ProcessMeta {
service: String::new(),
region: state.default_region.clone(),
account_id: state.default_account_id.clone(),
access_key: None,
};
let outcome = process_request(
state,
&method,
&uri,
&headers,
&body,
&request_id,
&mut meta,
)
.await;
let (status, mut resp_headers, resp_body, operation, error_code) = match outcome {
Ok(ProcessOk {
status,
headers,
body,
operation,
}) => (status, headers, body, Some(operation), None),
Err((protocol, error)) => {
warn!(
error_code = %error.code,
error_message = %error.message,
request_id = %request_id,
"Request failed"
);
let err_code = error.code.clone();
let (status, resp_headers, resp_body) =
protocol::serialize_error(protocol, &error, &request_id);
(
status,
resp_headers,
ProcessBody::Bytes(resp_body),
None,
Some(err_code),
)
}
};
let status_code = status.as_u16();
let response_size = resp_body.buffered_len().unwrap_or(0) as u64;
let body_cap = state.request_details.body_cap();
let captured_response = match &resp_body {
ProcessBody::Bytes(b) => capture_body(b, body_cap),
ProcessBody::Stream(_) => CapturedBody::placeholder("<streaming response>"),
};
let detail = RequestDetail {
id: request_id.clone(),
method: method.to_string(),
path: uri.path().to_string(),
query: uri.query().map(|q| q.to_string()),
status_code,
request_headers: capture_headers(&headers),
response_headers: capture_headers(&resp_headers),
request_body: capture_body(&body, body_cap),
response_body: captured_response,
};
state.request_details.insert(detail);
let memory_mb = resp_headers
.remove("x-awsim-memory-mb")
.and_then(|v| v.to_str().ok().and_then(|s| s.parse::<u32>().ok()));
let state_transitions = resp_headers
.remove("x-awsim-state-transitions")
.and_then(|v| v.to_str().ok().and_then(|s| s.parse::<u32>().ok()));
let character_count = resp_headers
.remove("x-awsim-char-count")
.and_then(|v| v.to_str().ok().and_then(|s| s.parse::<u64>().ok()));
let mut builder = Response::builder().status(status);
for (key, value) in resp_headers.drain() {
if let Some(key) = key {
builder = builder.header(key, value);
}
}
let body_for_response = match resp_body {
ProcessBody::Bytes(b) => Body::from(b),
ProcessBody::Stream(s) => {
use futures::StreamExt;
let mapped = s.map(|res| match res {
Ok(b) => Ok::<_, std::io::Error>(b),
Err(e) => {
let payload = format!("{{\"error\":\"{}\"}}", e.message);
Ok(Bytes::from(payload))
}
});
Body::from_stream(mapped)
}
};
let response = builder.body(body_for_response).unwrap();
let duration_ms = started.elapsed().as_secs_f64() * 1000.0;
let ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs_f64())
.unwrap_or(0.0);
let principal_arn = meta
.access_key
.as_ref()
.map(|ak| format!("arn:aws:iam::{}:access-key/{}", meta.account_id, ak));
let event = RequestEvent {
id: request_id.clone(),
ts,
method: method.to_string(),
path: uri.path().to_string(),
service: meta.service.clone(),
operation: operation.clone(),
account_id: meta.account_id.clone(),
region: meta.region.clone(),
principal_arn: principal_arn.clone(),
status_code,
duration_ms,
request_size,
response_size,
error_code: error_code.clone(),
memory_mb,
state_transitions,
character_count,
};
state.events.publish(event);
if !meta.service.is_empty() {
let user_agent = headers
.get(axum::http::header::USER_AGENT)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let api_event = crate::events::ApiCallDetail {
event_id: request_id.clone(),
event_source: format!("{}.amazonaws.com", meta.service),
event_name: operation.clone().unwrap_or_default(),
event_time_epoch: ts,
source_ip: state.request_details.get(&request_id).and_then(|d| {
d.request_headers.iter().find_map(|h| {
if h.name.eq_ignore_ascii_case("x-forwarded-for") {
Some(h.value.clone())
} else {
None
}
})
}),
user_agent,
user_identity_arn: principal_arn,
user_identity_account: Some(meta.account_id.clone()),
request_parameters: None,
response_elements: None,
error_code,
error_message: None,
http_status: status_code,
};
state
.event_bus
.publish_api_call(meta.region, meta.account_id, api_event);
}
(response, request_id)
}
/// Resolves, authenticates, authorizes and executes a single request.
///
/// On success returns the pieces of the response; on failure returns the
/// error together with the protocol that should be used to serialize it.
/// `meta` is updated as soon as information becomes available so the caller
/// can attribute the request even when an error is returned.
///
/// Steps, in order:
/// 1. derive service / region / account / access key from the request;
/// 2. if signed requests are required, reject missing or unknown access keys;
/// 3. if SigV4 verification is enabled, verify the signature of non-admin keys;
/// 4. record access-key usage for non-admin keys;
/// 5. look up the handler and parse the request, falling back to the
///    path-derived service when the auth-derived one has no matching operation;
/// 6. run the IAM check when the handler declares an action and a resource;
/// 7. apply chaos rules (latency and/or injected error);
/// 8. invoke the handler and serialize its result.
async fn process_request(
    state: &AppState,
    method: &Method,
    uri: &Uri,
    headers: &HeaderMap,
    body: &Bytes,
    request_id: &str,
    meta: &mut ProcessMeta,
) -> Result<ProcessOk, (Protocol, AwsError)> {
    let (mut service_name, region, account_id, access_key) =
        extract_service_info(state, headers, uri);
    meta.region = region.clone();
    meta.account_id = account_id.clone();
    meta.access_key = access_key.clone();

    if require_signed_requests_enabled() {
        let protocol = protocol::detect_protocol(headers, body).unwrap_or(Protocol::RestJson1);
        match access_key.as_deref() {
            None => {
                return Err((
                    protocol,
                    AwsError::bad_request(
                        "MissingAuthenticationTokenException",
                        "Request must be signed; no Authorization header found.",
                    ),
                ));
            }
            // A key that is neither the admin key nor known to the principal
            // lookup is rejected.
            Some(key)
                if !state.authz.is_admin_access_key(key)
                    && state
                        .authz
                        .principal_lookup
                        .resolve_access_key(key)
                        .is_none() =>
            {
                return Err((
                    protocol,
                    AwsError::bad_request(
                        "InvalidClientTokenId",
                        "The security token included in the request is invalid.",
                    ),
                ));
            }
            _ => {}
        }
    }

    // Signature verification is skipped for the admin key and for requests
    // without any access key.
    if crate::sigv4_verify::verify_enabled()
        && let Some(key) = access_key.as_deref()
        && !state.authz.is_admin_access_key(key)
    {
        let protocol = protocol::detect_protocol(headers, body).unwrap_or(Protocol::RestJson1);
        if let Err(err) = verify_signature_for_request(state, headers, method, uri, body, key) {
            return Err((protocol, err));
        }
    }

    if let Some(key) = access_key.as_deref()
        && !key.is_empty()
        && !state.authz.is_admin_access_key(key)
    {
        state
            .authz
            .principal_lookup
            .record_access_key_used(key, &service_name, &region);
    }

    let mut handler = state.services.get(&service_name).ok_or_else(|| {
        let protocol = protocol::detect_protocol(headers, body).unwrap_or(Protocol::RestJson1);
        (
            protocol,
            AwsError::bad_request(
                "UnknownService",
                format!("Service '{service_name}' is not registered"),
            ),
        )
    })?;
    let protocol = handler.protocol();
    // Prefer the protocol detected from the request; fall back to the handler's own.
    let mut detected = protocol::detect_protocol(headers, body).unwrap_or(protocol);
    let empty_routes = Vec::new();
    let routes = state.routes.get(&service_name).unwrap_or(&empty_routes);

    let parsed = match protocol::parse_request(detected, method, uri, headers, body, routes) {
        Ok(p) => p,
        Err(e) if e.code == "UnknownOperationException" => {
            // The credential scope may name a different service than the one the
            // path belongs to; retry with the path-derived service before
            // giving up.
            if let Some(path_service) = resolve_service_from_path(uri.path())
                && path_service != service_name
                && let Some(fallback_handler) = state.services.get(&path_service)
            {
                let fallback_routes = state.routes.get(&path_service).unwrap_or(&empty_routes);
                let fallback_protocol = fallback_handler.protocol();
                let fallback_detected =
                    protocol::detect_protocol(headers, body).unwrap_or(fallback_protocol);
                match protocol::parse_request(
                    fallback_detected,
                    method,
                    uri,
                    headers,
                    body,
                    fallback_routes,
                ) {
                    Ok(p) => {
                        debug!(
                            auth_service = %service_name,
                            path_service = %path_service,
                            "Auth-derived service had no matching route; falling back to path-derived service"
                        );
                        service_name = path_service;
                        handler = fallback_handler;
                        detected = fallback_detected;
                        p
                    }
                    // Report the original error, not the fallback's.
                    Err(_) => return Err((detected, e)),
                }
            } else {
                return Err((detected, e));
            }
        }
        Err(e) => return Err((detected, e)),
    };

    meta.service = service_name.clone();
    debug!(
        service = %service_name,
        operation = %parsed.operation,
        request_id = %request_id,
        "Dispatching operation"
    );

    // Only the first entry of `x-forwarded-for` is used as the source IP.
    let source_ip = headers
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.split(',').next().map(|s| s.trim().to_string()))
        .filter(|s| !s.is_empty());
    let is_secure = headers
        .get("x-forwarded-proto")
        .and_then(|v| v.to_str().ok())
        .map(|v| v.eq_ignore_ascii_case("https"))
        .unwrap_or(false);
    let ctx = crate::router::RequestContext {
        account_id,
        region,
        partition: state.default_partition.clone(),
        service: service_name.clone(),
        access_key,
        request_id: request_id.to_string(),
        method: method.to_string(),
        uri: uri.to_string(),
        event_bus: Some(state.event_bus.clone()),
        source_ip,
        is_secure,
        internal_bypass: false,
    };

    // The IAM check runs only when the handler declares both an action and a
    // resource for this operation.
    if let (Some(action), Some(resource)) = (
        handler.iam_action(&parsed.operation),
        handler.iam_resource(&parsed.operation, &parsed.input, &ctx),
    ) {
        state
            .authz
            .check(&ctx, &action, &resource)
            .map_err(|e| (detected, e))?;
    } else {
        debug!(
            service = %service_name,
            operation = %parsed.operation,
            "Skipping IAM check — handler does not declare action/resource"
        );
    }

    let operation = parsed.operation.clone();
    if let Some(outcome) = state.chaos.evaluate(&service_name, Some(&operation)) {
        if let Some(delay) = outcome.latency {
            tokio::time::sleep(delay).await;
        }
        state
            .chaos
            .record_injection(&outcome.rule_id, &service_name, Some(&operation));
        if let Some(err) = outcome.error {
            // An invalid injected status code falls back to 500.
            let status = axum::http::StatusCode::from_u16(err.status)
                .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
            let error_type = if status.is_server_error() {
                crate::error::ErrorType::Receiver
            } else {
                crate::error::ErrorType::Sender
            };
            let aws_err = AwsError {
                status,
                code: err.code,
                message: err.message,
                error_type,
                extras: None,
            };
            return Err((detected, aws_err));
        }
    }

    let handler_result = handler
        .handle_streaming(&parsed.operation, parsed.input, &ctx)
        .await
        .map_err(|e| (detected, e))?;
    match handler_result {
        crate::HandlerResult::Streaming { body, content_type } => {
            // Streaming results bypass the protocol serializer; headers that
            // fail to parse are silently omitted.
            let mut headers = HeaderMap::new();
            if let Ok(v) = content_type.parse() {
                headers.insert(axum::http::header::CONTENT_TYPE, v);
            }
            if let Ok(v) = request_id.parse() {
                headers.insert("x-amzn-requestid", v);
            }
            Ok(ProcessOk {
                status: StatusCode::OK,
                headers,
                body: ProcessBody::Stream(body),
                operation,
            })
        }
        crate::HandlerResult::Json(value) => {
            let (status, headers, body) =
                protocol::serialize_response(detected, &parsed.operation, &value, request_id);
            Ok(ProcessOk {
                status,
                headers,
                body: ProcessBody::Bytes(body),
                operation,
            })
        }
    }
}
fn verify_signature_for_request(
state: &AppState,
headers: &HeaderMap,
method: &Method,
uri: &Uri,
body: &Bytes,
access_key: &str,
) -> Result<(), AwsError> {
if let Some(query) = uri.query()
&& query.contains("X-Amz-Signature=")
{
return verify_presigned_for_request(state, headers, method, uri, access_key);
}
let auth_value = headers
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.ok_or_else(|| {
AwsError::bad_request(
"MissingAuthenticationTokenException",
"Request must be signed with SigV4 when AWSIM_VERIFY_SIGV4 is on.",
)
})?;
let auth = crate::sigv4_verify::parse_authorization_header(auth_value).ok_or_else(|| {
AwsError::bad_request(
"IncompleteSignatureException",
"Authorization header is not in the expected AWS4-HMAC-SHA256 shape.",
)
})?;
let amz_date = headers
.get("x-amz-date")
.and_then(|v| v.to_str().ok())
.ok_or_else(|| {
AwsError::bad_request(
"IncompleteSignatureException",
"x-amz-date header is required for SigV4-signed requests.",
)
})?;
let payload_hash_header = headers
.get("x-amz-content-sha256")
.and_then(|v| v.to_str().ok());
let secret = state
.authz
.principal_lookup
.resolve_secret(access_key)
.ok_or_else(|| {
AwsError::bad_request(
"InvalidClientTokenId",
"The security token included in the request is invalid.",
)
})?;
let mut headers_for_canonical: Vec<(String, String)> =
Vec::with_capacity(auth.signed_headers.len());
for name in &auth.signed_headers {
let lower = name.to_ascii_lowercase();
let value = if lower == "host" {
headers
.get("host")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.or_else(|| uri.host().map(|s| s.to_string()))
.unwrap_or_default()
} else {
headers
.get(lower.as_str())
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string()
};
headers_for_canonical.push((lower, value));
}
let path = uri.path();
let query = uri.query().unwrap_or("");
let canonical_query = canonicalize_query(query);
let outcome = crate::sigv4_verify::verify(
&auth,
&secret,
method.as_str(),
path,
&canonical_query,
&headers_for_canonical,
amz_date,
body,
payload_hash_header,
std::time::SystemTime::now(),
std::time::Duration::from_secs(300),
);
match outcome {
crate::sigv4_verify::VerifyOutcome::Ok => Ok(()),
crate::sigv4_verify::VerifyOutcome::IncompleteSignature => Err(AwsError::bad_request(
"IncompleteSignatureException",
"SigV4 verification failed: required header missing.",
)),
crate::sigv4_verify::VerifyOutcome::SignatureMismatch => Err(AwsError::forbidden(
"SignatureDoesNotMatch",
"The request signature we calculated does not match the signature you provided.",
)),
}
}
fn verify_presigned_for_request(
state: &AppState,
headers: &HeaderMap,
method: &Method,
uri: &Uri,
access_key: &str,
) -> Result<(), AwsError> {
let raw_query = uri.query().unwrap_or("");
let host = headers
.get("host")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.or_else(|| uri.host().map(|s| s.to_string()))
.unwrap_or_default();
let secret = state
.authz
.principal_lookup
.resolve_secret(access_key)
.ok_or_else(|| {
AwsError::bad_request(
"InvalidClientTokenId",
"The security token included in the request is invalid.",
)
})?;
let outcome = crate::sigv4_verify::verify_presigned(
method.as_str(),
uri.path(),
raw_query,
&host,
&secret,
std::time::SystemTime::now(),
std::time::Duration::from_secs(300),
);
match outcome {
crate::sigv4_verify::VerifyOutcome::Ok => Ok(()),
crate::sigv4_verify::VerifyOutcome::IncompleteSignature => Err(AwsError::bad_request(
"IncompleteSignatureException",
"Presigned URL is missing one of X-Amz-Algorithm / X-Amz-Credential / X-Amz-Date / X-Amz-SignedHeaders / X-Amz-Signature.",
)),
crate::sigv4_verify::VerifyOutcome::SignatureMismatch => Err(AwsError::forbidden(
"SignatureDoesNotMatch",
"The request signature we calculated does not match the signature you provided.",
)),
}
}
/// Produces the canonical form of a raw query string: parameters sorted by
/// name and then by value, each rendered as `name=value` and joined with `&`.
///
/// A parameter without `=` is rendered with an empty value. Names and values
/// are used exactly as given; no percent-decoding or re-encoding is applied.
/// An empty input yields an empty string.
fn canonicalize_query(query: &str) -> String {
    if query.is_empty() {
        return String::new();
    }

    let mut pairs: Vec<(&str, &str)> = query
        .split('&')
        .map(|segment| segment.split_once('=').unwrap_or((segment, "")))
        .collect();
    pairs.sort();

    let mut canonical = String::with_capacity(query.len() + pairs.len());
    for (index, (key, value)) in pairs.iter().enumerate() {
        if index > 0 {
            canonical.push('&');
        }
        canonical.push_str(key);
        canonical.push('=');
        canonical.push_str(value);
    }
    canonical
}
/// Determines `(service, region, account_id, access_key)` for a request.
///
/// Sources are tried in this order, and the first that yields a service wins:
/// 1. the `Authorization` header's credential scope;
/// 2. the `X-Amz-Credential` query parameter of a presigned URL
///    (`<access key>/<date>/<region>/<service>/...`);
/// 3. the `X-Amz-Target` header;
/// 4. a label of the `Host` header that matches a registered service;
/// 5. the request path.
///
/// When nothing matches, a warning is logged and the service is `"unknown"`.
/// The account id is always the configured default; region and access key are
/// only taken from the request in cases 1 and 2.
fn extract_service_info(
    state: &AppState,
    headers: &HeaderMap,
    uri: &Uri,
) -> (String, String, String, Option<String>) {
    const CREDENTIAL_PARAM: &str = "X-Amz-Credential=";

    if let Some(auth_header) = headers.get("authorization").and_then(|v| v.to_str().ok())
        && let Some(creds) = auth::parse_authorization(auth_header)
    {
        return (
            creds.service,
            creds.region,
            state.default_account_id.clone(),
            Some(creds.access_key),
        );
    }

    if let Some(query) = uri.query()
        && let Some(cred_start) = query.find(CREDENTIAL_PARAM)
    {
        let cred_val = &query[cred_start + CREDENTIAL_PARAM.len()..];
        let cred = cred_val.split('&').next().unwrap_or(cred_val);
        // Percent-encoding hex digits are case-insensitive, so both spellings
        // of an encoded '/' have to be decoded.
        let cred_decoded = cred.replace("%2F", "/").replace("%2f", "/");
        let parts: Vec<&str> = cred_decoded.split('/').collect();
        if let [access_key, _date, region, service, ..] = parts.as_slice() {
            return (
                service.to_string(),
                region.to_string(),
                state.default_account_id.clone(),
                Some(access_key.to_string()),
            );
        }
    }

    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok())
        && let Some(service) = resolve_service_from_target(target)
    {
        return (
            service,
            state.default_region.clone(),
            state.default_account_id.clone(),
            None,
        );
    }

    if let Some(host) = headers.get("host").and_then(|v| v.to_str().ok())
        && let Some(service) = extract_service_from_host(host, state)
    {
        return (
            service,
            state.default_region.clone(),
            state.default_account_id.clone(),
            None,
        );
    }

    let path = uri.path();
    if let Some(service) = resolve_service_from_path(path) {
        return (
            service,
            state.default_region.clone(),
            state.default_account_id.clone(),
            None,
        );
    }

    warn!(
        auth = ?headers.get("authorization").map(|v| v.to_str().unwrap_or("<non-utf8>")),
        target = ?headers.get("x-amz-target").map(|v| v.to_str().unwrap_or("<non-utf8>")),
        host = ?headers.get("host").map(|v| v.to_str().unwrap_or("<non-utf8>")),
        path = %path,
        "Could not determine service — falling back to 'unknown'"
    );
    (
        "unknown".to_string(),
        state.default_region.clone(),
        state.default_account_id.clone(),
        None,
    )
}
/// Maps an `X-Amz-Target` value to a service signing name.
///
/// Only the part of `target` before the first `.` is considered; it is matched
/// by prefix against a fixed table, first match wins. Returns `None` when no
/// entry matches.
fn resolve_service_from_target(target: &str) -> Option<String> {
    /// `(target prefix, service)` pairs, checked in order.
    const TARGET_PREFIXES: &[(&str, &str)] = &[
        ("DynamoDB", "dynamodb"),
        ("AmazonSQS", "sqs"),
        ("AmazonSNS", "sns"),
        ("TrentService", "kms"),
        ("secretsmanager", "secretsmanager"),
        ("AmazonSSM", "ssm"),
        ("Logs", "logs"),
        ("Kinesis", "kinesis"),
        ("AWSStepFunctions", "states"),
        ("AWSEvents", "events"),
        ("AWSCognitoIdentityProviderService", "cognito-idp"),
        ("AWSCognitoIdentityService", "cognito-identity"),
        ("AmazonEC2ContainerServiceV2", "ecs"),
        ("AmazonEC2ContainerRegistry", "ecr"),
        ("AmazonAthena", "athena"),
        ("AWSGlue", "glue"),
        ("CertificateManager", "acm"),
        ("AWSWAF", "wafv2"),
        ("Comprehend", "comprehend"),
        ("kendra", "kendra"),
        ("AWSOrganizationsV", "organizations"),
        ("CloudTrail_", "cloudtrail"),
        ("Firehose_", "firehose"),
        ("ResourceGroupsTaggingAPI", "tagging"),
        ("AnyScaleFrontendService", "application-autoscaling"),
        ("Route53AutoNaming_v", "servicediscovery"),
        ("AmazonMemoryDB", "memorydb"),
    ];

    let head = target.split('.').next()?;
    TARGET_PREFIXES
        .iter()
        .find(|(target_prefix, _)| head.starts_with(*target_prefix))
        .map(|(_, service)| (*service).to_string())
}
/// Finds a registered service named by one of the dot-separated labels of
/// `host`, ignoring anything after the first `:` (the port).
///
/// Labels are checked left to right; the first one that is a key of
/// `state.services` is returned.
fn extract_service_from_host(host: &str, state: &AppState) -> Option<String> {
    let hostname = host.split(':').next().unwrap_or(host);
    hostname
        .split('.')
        .find(|label| state.services.contains_key(*label))
        .map(str::to_string)
}
/// Maps a request path to a service signing name using fixed path prefixes.
///
/// Rules are checked in order and the first match wins. The bare path `/tags`
/// (exact match) also maps to `eks`. Returns `None` when nothing matches.
fn resolve_service_from_path(path: &str) -> Option<String> {
    /// `(path prefixes, service)` rules, checked in order.
    const PREFIX_RULES: &[(&[&str], &str)] = &[
        (&["/2015-03-31/functions", "/2018-10-31/layers"], "lambda"),
        (&["/v2/apis"], "execute-api"),
        (&["/v2/email"], "ses"),
        (
            &[
                "/2013-04-01/hostedzone",
                "/2013-04-01/healthcheck",
                "/2013-04-01/tags",
            ],
            "route53",
        ),
        (
            &[
                "/2020-05-31/distribution",
                "/2020-05-31/origin-access-control",
                "/2020-05-31/cache-policy",
                "/2020-05-31/tagging",
            ],
            "cloudfront",
        ),
        (&["/v1/apis"], "appsync"),
        (
            &["/foundation-models", "/guardrails", "/model-customization"],
            "bedrock",
        ),
        (&["/model/"], "bedrock-runtime"),
        (&["/schedules", "/schedule-groups"], "scheduler"),
        (&["/clusters", "/tags/"], "eks"),
    ];

    let matched = PREFIX_RULES
        .iter()
        .find(|(prefixes, _)| prefixes.iter().any(|prefix| path.starts_with(*prefix)));
    match matched {
        Some((_, service)) => Some((*service).to_string()),
        // `/tags` on its own is matched exactly, not by prefix.
        None if path == "/tags" => Some("eks".to_string()),
        None => None,
    }
}
// Unit tests for `is_browser_probe`.
#[cfg(test)]
mod browser_probe_tests {
use super::*;
// GET requests to the known probe paths are recognised.
#[test]
fn matches_known_probes() {
assert!(is_browser_probe(&Method::GET, "/favicon.ico"));
assert!(is_browser_probe(
&Method::GET,
"/.well-known/appspecific/com.chrome.devtools.json"
));
assert!(is_browser_probe(&Method::GET, "/robots.txt"));
assert!(is_browser_probe(&Method::GET, "/apple-touch-icon.png"));
}
// A probe path requested with a method other than GET is not a probe.
#[test]
fn ignores_non_get_methods() {
assert!(!is_browser_probe(&Method::PUT, "/favicon.ico"));
assert!(!is_browser_probe(&Method::POST, "/favicon.ico"));
}
// Paths outside the fixed list are never treated as probes.
#[test]
fn ignores_unknown_paths() {
assert!(!is_browser_probe(&Method::GET, "/"));
assert!(!is_browser_probe(&Method::GET, "/some-bucket/key"));
assert!(!is_browser_probe(&Method::GET, "/_awsim/stats"));
}
}
// Tests for `spawn_tick_loop` (tick delivery, panic isolation, per-service
// timeout) plus one test for presigned-URL signature rejection.
//
// The tick-loop tests rely on wall-clock sleeps, so their assertions use lower
// bounds rather than exact tick counts.
#[cfg(test)]
mod tick_tests {
use super::*;
use crate::RequestContext;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
// Test service whose `tick` increments a shared counter.
struct CountingService {
ticks: Arc<AtomicU64>,
}
#[async_trait::async_trait]
impl ServiceHandler for CountingService {
fn service_name(&self) -> &str {
"test"
}
fn protocol(&self) -> Protocol {
Protocol::AwsJson1_1
}
async fn handle(
&self,
_: &str,
_: serde_json::Value,
_: &RequestContext,
) -> Result<serde_json::Value, AwsError> {
Ok(serde_json::Value::Null)
}
async fn tick(&self) {
self.ticks.fetch_add(1, Ordering::SeqCst);
}
}
// With a 50 ms interval and a 200 ms wait, the single registered service
// must have been ticked at least twice.
#[tokio::test]
async fn tick_loop_invokes_each_registered_service() {
let counter = Arc::new(AtomicU64::new(0));
let svc = Arc::new(CountingService {
ticks: counter.clone(),
}) as Arc<dyn ServiceHandler>;
let mut services: HashMap<String, Arc<dyn ServiceHandler>> = HashMap::new();
services.insert("test".to_string(), svc);
let mut state = AppState::new("us-east-1".to_string(), "000000000000".to_string());
state.services = Arc::new(services);
let handle = spawn_tick_loop(state, Duration::from_millis(50));
tokio::time::sleep(Duration::from_millis(200)).await;
handle.abort();
let count = counter.load(Ordering::SeqCst);
assert!(count >= 2, "expected at least 2 ticks, got {count}");
}
// Test service whose `tick` always panics.
struct PanickingService;
#[async_trait::async_trait]
impl ServiceHandler for PanickingService {
fn service_name(&self) -> &str {
"panicky"
}
fn protocol(&self) -> Protocol {
Protocol::AwsJson1_1
}
async fn handle(
&self,
_: &str,
_: serde_json::Value,
_: &RequestContext,
) -> Result<serde_json::Value, AwsError> {
Ok(serde_json::Value::Null)
}
async fn tick(&self) {
panic!("intentional test panic from tick");
}
}
// A service that panics on every tick must not prevent a sibling service
// from being ticked.
#[tokio::test]
async fn panicking_service_does_not_stop_other_services_ticking() {
let counter = Arc::new(AtomicU64::new(0));
let counting = Arc::new(CountingService {
ticks: counter.clone(),
}) as Arc<dyn ServiceHandler>;
let panicky = Arc::new(PanickingService) as Arc<dyn ServiceHandler>;
let mut services: HashMap<String, Arc<dyn ServiceHandler>> = HashMap::new();
services.insert("counting".to_string(), counting);
services.insert("panicky".to_string(), panicky);
let mut state = AppState::new("us-east-1".to_string(), "000000000000".to_string());
state.services = Arc::new(services);
let handle = spawn_tick_loop(state, Duration::from_millis(30));
tokio::time::sleep(Duration::from_millis(200)).await;
handle.abort();
let count = counter.load(Ordering::SeqCst);
assert!(
count >= 3,
"counting service should have continued ticking despite sibling panic; got {count}"
);
}
// Test service whose `tick` counts the call and then sleeps for 60 seconds,
// far longer than the loop's per-service deadline.
struct SlowService {
ticks: Arc<AtomicU64>,
}
#[async_trait::async_trait]
impl ServiceHandler for SlowService {
fn service_name(&self) -> &str {
"slow"
}
fn protocol(&self) -> Protocol {
Protocol::AwsJson1_1
}
async fn handle(
&self,
_: &str,
_: serde_json::Value,
_: &RequestContext,
) -> Result<serde_json::Value, AwsError> {
Ok(serde_json::Value::Null)
}
async fn tick(&self) {
self.ticks.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_secs(60)).await;
}
}
// Principal lookup stub: never resolves an access key to a principal, but
// returns the same fixed secret for every access key.
struct FixedSecretLookup {
secret: String,
}
impl crate::authz::PrincipalLookup for FixedSecretLookup {
fn resolve_access_key(&self, _: &str) -> Option<crate::authz::ResolvedPrincipal> {
None
}
fn resolve_secret(&self, _: &str) -> Option<String> {
Some(self.secret.clone())
}
}
// A presigned URL carrying an all-zero signature must be rejected with
// `SignatureDoesNotMatch` and HTTP 403.
#[test]
fn presigned_url_tampering_surfaces_403_forbidden() {
use axum::http::HeaderValue;
let secret = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY";
let mut state = AppState::new("us-east-1".to_string(), "000000000000".to_string());
let mut authz = crate::authz::AuthzEngine::new(false);
authz.principal_lookup = Arc::new(FixedSecretLookup {
secret: secret.to_string(),
});
state.authz = Arc::new(authz);
let raw_query = "X-Amz-Algorithm=AWS4-HMAC-SHA256\
&X-Amz-Credential=AKID%2F20260524%2Fus-east-1%2Fs3%2Faws4_request\
&X-Amz-Date=20260524T120000Z\
&X-Amz-Expires=900\
&X-Amz-SignedHeaders=host\
&X-Amz-Signature=00000000000000000000000000000000\
00000000000000000000000000000000";
let uri: Uri = format!("/bucket/key?{raw_query}").parse().unwrap();
let mut headers = HeaderMap::new();
headers.insert("host", HeaderValue::from_static("s3.amazonaws.com"));
let err = verify_presigned_for_request(&state, &headers, &Method::GET, &uri, "AKID")
.expect_err("tampered presigned URL must be rejected");
assert_eq!(err.code, "SignatureDoesNotMatch");
assert_eq!(
err.status,
StatusCode::FORBIDDEN,
"tampered presigned URL must surface as HTTP 403 (real AWS behaviour)"
);
}
// A service whose tick never finishes in time must be cut off by the
// per-service deadline so that the fast service keeps being ticked.
#[tokio::test]
async fn slow_service_is_timed_out_so_loop_keeps_running() {
let slow_ticks = Arc::new(AtomicU64::new(0));
let fast_ticks = Arc::new(AtomicU64::new(0));
let slow = Arc::new(SlowService {
ticks: slow_ticks.clone(),
}) as Arc<dyn ServiceHandler>;
let fast = Arc::new(CountingService {
ticks: fast_ticks.clone(),
}) as Arc<dyn ServiceHandler>;
let mut services: HashMap<String, Arc<dyn ServiceHandler>> = HashMap::new();
services.insert("slow".to_string(), slow);
services.insert("fast".to_string(), fast);
let mut state = AppState::new("us-east-1".to_string(), "000000000000".to_string());
state.services = Arc::new(services);
let handle = spawn_tick_loop(state, Duration::from_millis(80));
tokio::time::sleep(Duration::from_millis(400)).await;
handle.abort();
let fast = fast_ticks.load(Ordering::SeqCst);
assert!(
fast >= 3,
"fast service should keep ticking past the slow one; got {fast}"
);
}
}