use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use axum::body::Body;
use axum::extract::State;
use axum::http::{HeaderMap, Method, Response, StatusCode, Uri};
use bytes::Bytes;
use tracing::{debug, info, warn};
use crate::ServiceHandler;
use crate::auth;
use crate::authz::AuthzEngine;
use crate::body_store::BodyStore;
use crate::error::AwsError;
use crate::events::EventBus;
use crate::protocol::{self, Protocol, RouteDefinition};
use crate::request_detail::{
CapturedBody, RequestDetail, RequestDetailStore, capture_body, capture_headers,
};
use crate::request_event::{RequestEvent, RequestEventBus};
/// A [`BodyStore`] tagged with the service name and group labels it is associated with.
///
/// The store itself sits behind an `Arc`, so cloning a handle shares the same store.
#[derive(Clone)]
pub struct BodyStoreHandle {
    /// Name of the service associated with this store.
    pub service_name: String,
    /// Group labels attached to this store.
    // NOTE(review): how these groups are interpreted is decided by readers of
    // `AppState::body_stores`, which are not in this file — confirm.
    pub groups: Vec<String>,
    /// The shared store.
    pub body_store: Arc<BodyStore>,
}
#[derive(Clone)]
pub struct AppState {
pub services: Arc<HashMap<String, Arc<dyn ServiceHandler>>>,
pub routes: Arc<HashMap<String, Vec<RouteDefinition>>>,
pub default_region: String,
pub default_account_id: String,
pub event_bus: EventBus,
pub request_count: Arc<AtomicU64>,
pub start_time: std::time::Instant,
pub authz: Arc<AuthzEngine>,
pub body_stores: Arc<Vec<BodyStoreHandle>>,
pub data_dir: Option<Arc<std::path::PathBuf>>,
pub events: RequestEventBus,
pub request_details: RequestDetailStore,
pub chaos: Arc<awsim_chaos::ChaosEngine>,
}
impl AppState {
pub fn new(default_region: String, default_account_id: String) -> Self {
Self {
services: Arc::new(HashMap::new()),
routes: Arc::new(HashMap::new()),
default_region,
default_account_id,
event_bus: EventBus::new(),
request_count: Arc::new(AtomicU64::new(0)),
start_time: std::time::Instant::now(),
authz: Arc::new(AuthzEngine::from_env()),
body_stores: Arc::new(Vec::new()),
data_dir: None,
events: RequestEventBus::new(),
request_details: RequestDetailStore::default(),
chaos: Arc::new(awsim_chaos::ChaosEngine::new()),
}
}
pub fn register(&mut self, handler: Arc<dyn ServiceHandler>, routes: Vec<RouteDefinition>) {
let signing_name = handler.signing_name().to_string();
let service_name = handler.service_name().to_string();
info!(
service = %service_name,
signing_name = %signing_name,
protocol = ?handler.protocol(),
routes = routes.len(),
"Registered service"
);
Arc::get_mut(&mut self.services)
.unwrap()
.insert(signing_name.clone(), handler);
if !routes.is_empty() {
Arc::get_mut(&mut self.routes)
.unwrap()
.insert(signing_name, routes);
}
}
}
/// Successful result of [`process_request`]: the pieces of the HTTP response
/// plus the operation name reported in the request event.
struct ProcessOk {
    status: StatusCode,
    headers: HeaderMap,
    body: ProcessBody,
    operation: String,
}

/// Response payload produced by a handler.
enum ProcessBody {
    /// Fully buffered body.
    Bytes(Bytes),
    /// Body delivered as a stream from the handler.
    Stream(crate::HandlerByteStream),
}

impl ProcessBody {
    /// Byte length of a buffered body; `None` for a stream.
    fn buffered_len(&self) -> Option<usize> {
        match self {
            Self::Bytes(bytes) => Some(bytes.len()),
            Self::Stream(_) => None,
        }
    }
}

/// Request attributes collected by [`process_request`] as they are resolved,
/// so the caller can still report them when processing fails.
struct ProcessMeta {
    service: String,
    region: String,
    account_id: String,
    access_key: Option<String>,
}
/// Spawns a background task that calls `tick()` on every registered service
/// once per `interval`, forever (until the returned handle is aborted).
///
/// Services are ticked one after another. Missed ticks are skipped rather
/// than replayed in a burst.
pub fn spawn_tick_loop(
    state: AppState,
    interval: std::time::Duration,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        let mut period = tokio::time::interval(interval);
        period.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        // The first tick of a tokio interval completes immediately; consume it
        // so services are first ticked one full period after spawning.
        period.tick().await;
        loop {
            period.tick().await;
            // Take owned clones of the handlers, then tick each in turn.
            let handlers: Vec<Arc<dyn ServiceHandler>> =
                state.services.values().map(Arc::clone).collect();
            for handler in handlers {
                handler.tick().await;
            }
        }
    })
}
/// Axum entry point for every incoming request.
///
/// Well-known browser probes (see [`is_browser_probe`]) get an empty
/// `204 No Content` without being counted or recorded; everything else is
/// handed to [`dispatch_request`].
pub async fn handle_request(
    State(state): State<AppState>,
    method: Method,
    uri: Uri,
    headers: HeaderMap,
    body: Bytes,
) -> Response<Body> {
    if !is_browser_probe(&method, uri.path()) {
        let (response, _request_id) = dispatch_request(&state, method, uri, headers, body).await;
        return response;
    }
    let mut no_content = Response::new(Body::empty());
    *no_content.status_mut() = StatusCode::NO_CONTENT;
    no_content
}
/// Returns `true` for `GET` requests to paths that browsers fetch on their own
/// (favicons, touch icons, `robots.txt`, Chrome DevTools discovery).
fn is_browser_probe(method: &Method, path: &str) -> bool {
    const PROBE_PATHS: [&str; 5] = [
        "/favicon.ico",
        "/apple-touch-icon.png",
        "/apple-touch-icon-precomposed.png",
        "/robots.txt",
        "/.well-known/appspecific/com.chrome.devtools.json",
    ];
    *method == Method::GET && PROBE_PATHS.contains(&path)
}
pub async fn dispatch_request(
state: &AppState,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Bytes,
) -> (Response<Body>, String) {
state.request_count.fetch_add(1, Ordering::Relaxed);
let request_id = uuid::Uuid::new_v4().to_string();
let started = Instant::now();
let request_size = body.len() as u64;
debug!(
method = %method,
uri = %uri,
request_id = %request_id,
"Incoming request"
);
let mut meta = ProcessMeta {
service: String::new(),
region: state.default_region.clone(),
account_id: state.default_account_id.clone(),
access_key: None,
};
let outcome = process_request(
state,
&method,
&uri,
&headers,
&body,
&request_id,
&mut meta,
)
.await;
let (status, mut resp_headers, resp_body, operation, error_code) = match outcome {
Ok(ProcessOk {
status,
headers,
body,
operation,
}) => (status, headers, body, Some(operation), None),
Err((protocol, error)) => {
warn!(
error_code = %error.code,
error_message = %error.message,
request_id = %request_id,
"Request failed"
);
let err_code = error.code.clone();
let (status, resp_headers, resp_body) =
protocol::serialize_error(protocol, &error, &request_id);
(
status,
resp_headers,
ProcessBody::Bytes(resp_body),
None,
Some(err_code),
)
}
};
let status_code = status.as_u16();
let response_size = resp_body.buffered_len().unwrap_or(0) as u64;
let body_cap = state.request_details.body_cap();
let captured_response = match &resp_body {
ProcessBody::Bytes(b) => capture_body(b, body_cap),
ProcessBody::Stream(_) => CapturedBody::placeholder("<streaming response>"),
};
let detail = RequestDetail {
id: request_id.clone(),
method: method.to_string(),
path: uri.path().to_string(),
query: uri.query().map(|q| q.to_string()),
status_code,
request_headers: capture_headers(&headers),
response_headers: capture_headers(&resp_headers),
request_body: capture_body(&body, body_cap),
response_body: captured_response,
};
state.request_details.insert(detail);
let memory_mb = resp_headers
.remove("x-awsim-memory-mb")
.and_then(|v| v.to_str().ok().and_then(|s| s.parse::<u32>().ok()));
let state_transitions = resp_headers
.remove("x-awsim-state-transitions")
.and_then(|v| v.to_str().ok().and_then(|s| s.parse::<u32>().ok()));
let character_count = resp_headers
.remove("x-awsim-char-count")
.and_then(|v| v.to_str().ok().and_then(|s| s.parse::<u64>().ok()));
let mut builder = Response::builder().status(status);
for (key, value) in resp_headers.drain() {
if let Some(key) = key {
builder = builder.header(key, value);
}
}
let body_for_response = match resp_body {
ProcessBody::Bytes(b) => Body::from(b),
ProcessBody::Stream(s) => {
use futures::StreamExt;
let mapped = s.map(|res| match res {
Ok(b) => Ok::<_, std::io::Error>(b),
Err(e) => {
let payload = format!("{{\"error\":\"{}\"}}", e.message);
Ok(Bytes::from(payload))
}
});
Body::from_stream(mapped)
}
};
let response = builder.body(body_for_response).unwrap();
let duration_ms = started.elapsed().as_secs_f64() * 1000.0;
let ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs_f64())
.unwrap_or(0.0);
let principal_arn = meta
.access_key
.as_ref()
.map(|ak| format!("arn:aws:iam::{}:access-key/{}", meta.account_id, ak));
let event = RequestEvent {
id: request_id.clone(),
ts,
method: method.to_string(),
path: uri.path().to_string(),
service: meta.service,
operation,
account_id: meta.account_id,
region: meta.region,
principal_arn,
status_code,
duration_ms,
request_size,
response_size,
error_code,
memory_mb,
state_transitions,
character_count,
};
state.events.publish(event);
(response, request_id)
}
async fn process_request(
state: &AppState,
method: &Method,
uri: &Uri,
headers: &HeaderMap,
body: &Bytes,
request_id: &str,
meta: &mut ProcessMeta,
) -> Result<ProcessOk, (Protocol, AwsError)> {
let (mut service_name, region, account_id, access_key) =
extract_service_info(state, headers, uri);
meta.region = region.clone();
meta.account_id = account_id.clone();
meta.access_key = access_key.clone();
let mut handler = state.services.get(&service_name).ok_or_else(|| {
let protocol = protocol::detect_protocol(headers, body).unwrap_or(Protocol::RestJson1);
(
protocol,
AwsError::bad_request(
"UnknownService",
format!("Service '{service_name}' is not registered"),
),
)
})?;
let protocol = handler.protocol();
let mut detected = protocol::detect_protocol(headers, body).unwrap_or(protocol);
let empty_routes = Vec::new();
let routes = state.routes.get(&service_name).unwrap_or(&empty_routes);
let parsed = match protocol::parse_request(detected, method, uri, headers, body, routes) {
Ok(p) => p,
Err(e) if e.code == "UnknownOperationException" => {
if let Some(path_service) = resolve_service_from_path(uri.path())
&& path_service != service_name
&& let Some(fallback_handler) = state.services.get(&path_service)
{
let fallback_routes = state.routes.get(&path_service).unwrap_or(&empty_routes);
let fallback_protocol = fallback_handler.protocol();
let fallback_detected =
protocol::detect_protocol(headers, body).unwrap_or(fallback_protocol);
match protocol::parse_request(
fallback_detected,
method,
uri,
headers,
body,
fallback_routes,
) {
Ok(p) => {
debug!(
auth_service = %service_name,
path_service = %path_service,
"Auth-derived service had no matching route; falling back to path-derived service"
);
service_name = path_service;
handler = fallback_handler;
detected = fallback_detected;
let _ = fallback_protocol;
let _ = fallback_routes;
p
}
Err(_) => return Err((detected, e)),
}
} else {
return Err((detected, e));
}
}
Err(e) => return Err((detected, e)),
};
meta.service = service_name.clone();
debug!(
service = %service_name,
operation = %parsed.operation,
request_id = %request_id,
"Dispatching operation"
);
let source_ip = headers
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.split(',').next().map(|s| s.trim().to_string()))
.filter(|s| !s.is_empty());
let is_secure = headers
.get("x-forwarded-proto")
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("https"))
.unwrap_or(false);
let ctx = crate::router::RequestContext {
account_id,
region,
service: service_name.clone(),
access_key,
request_id: request_id.to_string(),
method: method.to_string(),
uri: uri.to_string(),
event_bus: Some(state.event_bus.clone()),
source_ip,
is_secure,
};
if let (Some(action), Some(resource)) = (
handler.iam_action(&parsed.operation),
handler.iam_resource(&parsed.operation, &parsed.input, &ctx),
) {
state
.authz
.check(&ctx, &action, &resource)
.map_err(|e| (detected, e))?;
} else {
debug!(
service = %service_name,
operation = %parsed.operation,
"Skipping IAM check — handler does not declare action/resource"
);
}
let operation = parsed.operation.clone();
if let Some(outcome) = state.chaos.evaluate(&service_name, Some(&operation)) {
if let Some(delay) = outcome.latency {
tokio::time::sleep(delay).await;
}
state
.chaos
.record_injection(&outcome.rule_id, &service_name, Some(&operation));
if let Some(err) = outcome.error {
let status = axum::http::StatusCode::from_u16(err.status)
.unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
let error_type = if status.is_server_error() {
crate::error::ErrorType::Receiver
} else {
crate::error::ErrorType::Sender
};
let aws_err = AwsError {
status,
code: err.code,
message: err.message,
error_type,
extras: None,
};
return Err((detected, aws_err));
}
}
let handler_result = handler
.handle_streaming(&parsed.operation, parsed.input, &ctx)
.await
.map_err(|e| (detected, e))?;
match handler_result {
crate::HandlerResult::Streaming { body, content_type } => {
let mut headers = HeaderMap::new();
if let Ok(v) = content_type.parse() {
headers.insert(axum::http::header::CONTENT_TYPE, v);
}
if let Ok(v) = request_id.parse() {
headers.insert("x-amzn-requestid", v);
}
Ok(ProcessOk {
status: StatusCode::OK,
headers,
body: ProcessBody::Stream(body),
operation,
})
}
crate::HandlerResult::Json(value) => {
let (status, headers, body) =
protocol::serialize_response(detected, &parsed.operation, &value, request_id);
Ok(ProcessOk {
status,
headers,
body: ProcessBody::Bytes(body),
operation,
})
}
}
}
/// Determines `(service, region, account_id, access_key)` for a request.
///
/// Sources are tried in order and the first hit wins:
/// 1. the `Authorization` header, as parsed by `auth::parse_authorization`;
/// 2. a presigned-URL `X-Amz-Credential` query parameter;
/// 3. the `X-Amz-Target` header prefix;
/// 4. a `Host` label that names a registered service;
/// 5. a known path prefix.
///
/// The region falls back to `state.default_region` when no credential supplies
/// one, and the account id is always `state.default_account_id`. If nothing
/// matches, the service is `"unknown"` and a warning is logged.
fn extract_service_info(
    state: &AppState,
    headers: &HeaderMap,
    uri: &Uri,
) -> (String, String, String, Option<String>) {
    if let Some(auth_header) = headers.get("authorization").and_then(|v| v.to_str().ok())
        && let Some(creds) = auth::parse_authorization(auth_header)
    {
        return (
            creds.service,
            creds.region,
            state.default_account_id.clone(),
            Some(creds.access_key),
        );
    }
    if let Some(query) = uri.query()
        && let Some((access_key, region, service)) = presigned_credential(query)
    {
        return (
            service,
            region,
            state.default_account_id.clone(),
            Some(access_key),
        );
    }
    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok())
        && let Some(service) = resolve_service_from_target(target)
    {
        return (
            service,
            state.default_region.clone(),
            state.default_account_id.clone(),
            None,
        );
    }
    if let Some(host) = headers.get("host").and_then(|v| v.to_str().ok())
        && let Some(service) = extract_service_from_host(host, state)
    {
        return (
            service,
            state.default_region.clone(),
            state.default_account_id.clone(),
            None,
        );
    }
    let path = uri.path();
    if let Some(service) = resolve_service_from_path(path) {
        return (
            service,
            state.default_region.clone(),
            state.default_account_id.clone(),
            None,
        );
    }
    warn!(
        auth = ?headers.get("authorization").map(|v| v.to_str().unwrap_or("<non-utf8>")),
        target = ?headers.get("x-amz-target").map(|v| v.to_str().unwrap_or("<non-utf8>")),
        host = ?headers.get("host").map(|v| v.to_str().unwrap_or("<non-utf8>")),
        path = %path,
        "Could not determine service — falling back to 'unknown'"
    );
    (
        "unknown".to_string(),
        state.default_region.clone(),
        state.default_account_id.clone(),
        None,
    )
}

/// Extracts `(access_key, region, service)` from the `X-Amz-Credential`
/// parameter of a presigned-URL query string, whose value has the form
/// `<access-key>/<date>/<region>/<service>/...`.
///
/// Only an exact `X-Amz-Credential` key is accepted (not a substring of some
/// other parameter). The `/` separators may be percent-encoded in either hex
/// case (`%2F` or `%2f`). Returns `None` when the parameter is missing or has
/// fewer than four `/`-separated parts.
fn presigned_credential(query: &str) -> Option<(String, String, String)> {
    let raw = query
        .split('&')
        .find_map(|pair| pair.strip_prefix("X-Amz-Credential="))?;
    let decoded = raw.replace("%2F", "/").replace("%2f", "/");
    let mut parts = decoded.split('/');
    let access_key = parts.next()?;
    let _date = parts.next()?;
    let region = parts.next()?;
    let service = parts.next()?;
    Some((
        access_key.to_string(),
        region.to_string(),
        service.to_string(),
    ))
}
/// Maps an `X-Amz-Target` header value (`<prefix>.<Operation>`) to a service name.
///
/// The first `.`-separated segment is matched against a table of known
/// prefixes (first match wins). Some services namespace their prefix, e.g.
/// CloudTrail's `com.amazonaws.cloudtrail.v20131101.CloudTrail_20131101.CreateTrail`.
/// For these, the segment immediately before the operation name is tried as well.
/// Returns `None` when neither segment matches.
fn resolve_service_from_target(target: &str) -> Option<String> {
    // Ordered: the first prefix the candidate segment starts with wins.
    const TARGET_PREFIXES: &[(&str, &str)] = &[
        ("DynamoDB", "dynamodb"),
        ("AmazonSQS", "sqs"),
        ("AmazonSNS", "sns"),
        ("TrentService", "kms"),
        ("secretsmanager", "secretsmanager"),
        ("AmazonSSM", "ssm"),
        ("Logs", "logs"),
        ("Kinesis", "kinesis"),
        ("AWSStepFunctions", "states"),
        ("AWSEvents", "events"),
        ("AWSCognitoIdentityProviderService", "cognito-idp"),
        ("AWSCognitoIdentityService", "cognito-identity"),
        ("AmazonEC2ContainerServiceV2", "ecs"),
        ("AmazonEC2ContainerRegistry", "ecr"),
        ("AmazonAthena", "athena"),
        ("AWSGlue", "glue"),
        ("CertificateManager", "acm"),
        ("AWSWAF", "wafv2"),
        ("Comprehend", "comprehend"),
        ("kendra", "kendra"),
        // Prefix used by the public Kendra API.
        ("AWSKendraFrontendService", "kendra"),
        ("AWSOrganizationsV", "organizations"),
        ("CloudTrail_", "cloudtrail"),
        ("Firehose_", "firehose"),
        ("ResourceGroupsTaggingAPI", "tagging"),
        ("AnyScaleFrontendService", "application-autoscaling"),
        ("Route53AutoNaming_v", "servicediscovery"),
        ("AmazonMemoryDB", "memorydb"),
    ];
    let first_segment = target.split('.').next().unwrap_or(target);
    let segment_before_operation = target
        .rsplit_once('.')
        .map(|(namespace, _operation)| namespace.rsplit('.').next().unwrap_or(namespace));
    [Some(first_segment), segment_before_operation]
        .into_iter()
        .flatten()
        .find_map(|segment| {
            TARGET_PREFIXES
                .iter()
                .find(|&&(known, _)| segment.starts_with(known))
        })
        .map(|&(_, service)| service.to_string())
}
/// Finds the first dot-separated label of `host` (ignoring any `:port`
/// suffix) that is the key of a registered service.
fn extract_service_from_host(host: &str, state: &AppState) -> Option<String> {
    let hostname = match host.split_once(':') {
        Some((name, _port)) => name,
        None => host,
    };
    hostname
        .split('.')
        .find(|label| state.services.contains_key(*label))
        .map(str::to_owned)
}
/// Guesses the service from well-known REST path prefixes.
///
/// Rules are checked in order and the first matching prefix wins; EKS also
/// claims the bare `/tags` path (exact match). Returns `None` for any other path.
fn resolve_service_from_path(path: &str) -> Option<String> {
    const PREFIX_RULES: &[(&str, &[&str])] = &[
        ("lambda", &["/2015-03-31/functions", "/2018-10-31/layers"]),
        ("execute-api", &["/v2/apis"]),
        ("ses", &["/v2/email"]),
        (
            "route53",
            &[
                "/2013-04-01/hostedzone",
                "/2013-04-01/healthcheck",
                "/2013-04-01/tags",
            ],
        ),
        (
            "cloudfront",
            &[
                "/2020-05-31/distribution",
                "/2020-05-31/origin-access-control",
                "/2020-05-31/cache-policy",
                "/2020-05-31/tagging",
            ],
        ),
        ("appsync", &["/v1/apis"]),
        (
            "bedrock",
            &["/foundation-models", "/guardrails", "/model-customization"],
        ),
        ("bedrock-runtime", &["/model/"]),
        ("scheduler", &["/schedules", "/schedule-groups"]),
        ("eks", &["/clusters", "/tags/"]),
    ];
    let by_prefix = PREFIX_RULES
        .iter()
        .find(|(_, prefixes)| prefixes.iter().any(|prefix| path.starts_with(*prefix)))
        .map(|(service, _)| *service);
    // No prefix above matches the bare "/tags", so checking it last preserves priority.
    let service = by_prefix.or_else(|| (path == "/tags").then_some("eks"))?;
    Some(service.to_string())
}
#[cfg(test)]
mod browser_probe_tests {
    use super::*;

    /// Every well-known probe path is recognised for GET.
    #[test]
    fn matches_known_probes() {
        for path in [
            "/favicon.ico",
            "/.well-known/appspecific/com.chrome.devtools.json",
            "/robots.txt",
            "/apple-touch-icon.png",
        ] {
            assert!(
                is_browser_probe(&Method::GET, path),
                "{path} should be treated as a browser probe"
            );
        }
    }

    /// Probe paths are only short-circuited for GET.
    #[test]
    fn ignores_non_get_methods() {
        for method in [Method::PUT, Method::POST] {
            assert!(
                !is_browser_probe(&method, "/favicon.ico"),
                "{method} /favicon.ico must not be treated as a probe"
            );
        }
    }

    /// Ordinary API and admin paths are never treated as probes.
    #[test]
    fn ignores_unknown_paths() {
        for path in ["/", "/some-bucket/key", "/_awsim/stats"] {
            assert!(
                !is_browser_probe(&Method::GET, path),
                "{path} must not be treated as a browser probe"
            );
        }
    }
}
#[cfg(test)]
mod tick_tests {
    use super::*;
    use crate::RequestContext;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Duration;

    /// Test double whose only behaviour is counting `tick` invocations.
    struct CountingService {
        ticks: Arc<AtomicU64>,
    }

    #[async_trait::async_trait]
    impl ServiceHandler for CountingService {
        fn service_name(&self) -> &str {
            "test"
        }

        fn protocol(&self) -> Protocol {
            Protocol::AwsJson1_1
        }

        async fn handle(
            &self,
            _operation: &str,
            _input: serde_json::Value,
            _ctx: &RequestContext,
        ) -> Result<serde_json::Value, AwsError> {
            Ok(serde_json::Value::Null)
        }

        async fn tick(&self) {
            self.ticks.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn tick_loop_invokes_each_registered_service() {
        let ticks = Arc::new(AtomicU64::new(0));
        let handler: Arc<dyn ServiceHandler> = Arc::new(CountingService {
            ticks: Arc::clone(&ticks),
        });

        let mut state = AppState::new("us-east-1".to_string(), "000000000000".to_string());
        state.services = Arc::new(HashMap::from([("test".to_string(), handler)]));

        // With a 50ms period over ~200ms (the immediate first tick is skipped),
        // several ticks are expected; require at least two to stay robust.
        let task = spawn_tick_loop(state, Duration::from_millis(50));
        tokio::time::sleep(Duration::from_millis(200)).await;
        task.abort();

        let observed = ticks.load(Ordering::SeqCst);
        assert!(observed >= 2, "expected at least 2 ticks, got {observed}");
    }
}