Skip to main content

awsim_core/
gateway.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::{Instant, SystemTime, UNIX_EPOCH};
5
6use axum::body::Body;
7use axum::extract::State;
8use axum::http::{HeaderMap, Method, Response, StatusCode, Uri};
9use bytes::Bytes;
10use tracing::{debug, info, warn};
11
12use crate::ServiceHandler;
13use crate::auth;
14use crate::authz::AuthzEngine;
15use crate::body_store::BodyStore;
16use crate::error::AwsError;
17use crate::events::EventBus;
18use crate::protocol::{self, Protocol, RouteDefinition};
19use crate::request_detail::{RequestDetail, RequestDetailStore, capture_body, capture_headers};
20use crate::request_event::{RequestEvent, RequestEventBus};
21
22#[derive(Clone)]
23pub struct BodyStoreHandle {
24    pub service_name: String,
25    pub groups: Vec<String>,
26    pub body_store: Arc<BodyStore>,
27}
28
29/// Shared application state passed to all request handlers.
30#[derive(Clone)]
31pub struct AppState {
32    /// Registered service handlers, keyed by signing name.
33    pub services: Arc<HashMap<String, Arc<dyn ServiceHandler>>>,
34    /// Route definitions for REST-protocol services, keyed by signing name.
35    pub routes: Arc<HashMap<String, Vec<RouteDefinition>>>,
36    /// Default AWS region.
37    pub default_region: String,
38    /// Default AWS account ID.
39    pub default_account_id: String,
40    /// Internal event bus for cross-service fan-out (SNS→SQS, etc.).
41    pub event_bus: EventBus,
42    /// Total number of AWS API requests handled since startup.
43    pub request_count: Arc<AtomicU64>,
44    /// Server startup time.
45    pub start_time: std::time::Instant,
46    /// IAM authorization engine — opt-in via AWSIM_IAM_ENFORCE=true.
47    pub authz: Arc<AuthzEngine>,
48    /// Per-service `BodyStore` handles, populated when persistence is enabled.
49    pub body_stores: Arc<Vec<BodyStoreHandle>>,
50    /// Persistence root directory, when persistence is enabled.
51    pub data_dir: Option<Arc<std::path::PathBuf>>,
52    /// Broadcast bus for per-request observability events (consumed by SSE).
53    pub events: RequestEventBus,
54    /// Ring buffer of recent per-request detail captures (headers + bodies).
55    pub request_details: RequestDetailStore,
56    /// Chaos engine — empty by default, populated by admin endpoints.
57    /// Evaluated before dispatch to inject synthetic errors / latency.
58    pub chaos: Arc<awsim_chaos::ChaosEngine>,
59}
60
61impl AppState {
62    pub fn new(default_region: String, default_account_id: String) -> Self {
63        Self {
64            services: Arc::new(HashMap::new()),
65            routes: Arc::new(HashMap::new()),
66            default_region,
67            default_account_id,
68            event_bus: EventBus::new(),
69            request_count: Arc::new(AtomicU64::new(0)),
70            start_time: std::time::Instant::now(),
71            authz: Arc::new(AuthzEngine::from_env()),
72            body_stores: Arc::new(Vec::new()),
73            data_dir: None,
74            events: RequestEventBus::new(),
75            request_details: RequestDetailStore::default(),
76            chaos: Arc::new(awsim_chaos::ChaosEngine::new()),
77        }
78    }
79
80    /// Register a service handler.
81    pub fn register(&mut self, handler: Arc<dyn ServiceHandler>, routes: Vec<RouteDefinition>) {
82        let signing_name = handler.signing_name().to_string();
83        let service_name = handler.service_name().to_string();
84
85        info!(
86            service = %service_name,
87            signing_name = %signing_name,
88            protocol = ?handler.protocol(),
89            routes = routes.len(),
90            "Registered service"
91        );
92
93        Arc::get_mut(&mut self.services)
94            .unwrap()
95            .insert(signing_name.clone(), handler);
96
97        if !routes.is_empty() {
98            Arc::get_mut(&mut self.routes)
99                .unwrap()
100                .insert(signing_name, routes);
101        }
102    }
103}
104
105struct ProcessOk {
106    status: StatusCode,
107    headers: HeaderMap,
108    body: Bytes,
109    operation: String,
110}
111
/// Request attribution gathered while a request is processed.
///
/// `dispatch_request` seeds it with the state's defaults and
/// `process_request` overwrites the fields once the target service has been
/// identified, so the values are available for the request event even when
/// processing fails.
struct ProcessMeta {
    /// Service the request was attributed to (empty until identified).
    service: String,
    /// Region attributed to the request.
    region: String,
    /// Account ID attributed to the request.
    account_id: String,
    /// Access key taken from the request credentials, when present.
    access_key: Option<String>,
}
118
119/// Main request handler — all AWS API requests funnel through here.
120pub async fn handle_request(
121    State(state): State<AppState>,
122    method: Method,
123    uri: Uri,
124    headers: HeaderMap,
125    body: Bytes,
126) -> Response<Body> {
127    // Short-circuit common browser probes (favicon, devtools well-known
128    // path) before the AWS dispatch pipeline runs. They are not API
129    // calls — silently 204'ing keeps them out of the request log and
130    // out of the inspect drawer.
131    if is_browser_probe(&method, uri.path()) {
132        return Response::builder()
133            .status(StatusCode::NO_CONTENT)
134            .body(Body::empty())
135            .unwrap();
136    }
137    let (response, _id) = dispatch_request(&state, method, uri, headers, body).await;
138    response
139}
140
141/// Recognise the handful of unsolicited paths browsers hit when you point
142/// them at AWSim's port — they're not AWS requests and shouldn't appear
143/// in logs or stats. Conservative on purpose: only paths we've actually
144/// seen in real traces are listed.
145fn is_browser_probe(method: &Method, path: &str) -> bool {
146    if method != Method::GET {
147        return false;
148    }
149    matches!(
150        path,
151        "/favicon.ico"
152            | "/apple-touch-icon.png"
153            | "/apple-touch-icon-precomposed.png"
154            | "/robots.txt"
155            | "/.well-known/appspecific/com.chrome.devtools.json"
156    )
157}
158
159/// Same as `handle_request`, but takes the state by reference and returns
160/// the generated request id alongside the response. Lets internal callers
161/// (replay, etc.) drive the gateway pipeline without going through axum.
162pub async fn dispatch_request(
163    state: &AppState,
164    method: Method,
165    uri: Uri,
166    headers: HeaderMap,
167    body: Bytes,
168) -> (Response<Body>, String) {
169    state.request_count.fetch_add(1, Ordering::Relaxed);
170
171    let request_id = uuid::Uuid::new_v4().to_string();
172    let started = Instant::now();
173    let request_size = body.len() as u64;
174
175    debug!(
176        method = %method,
177        uri = %uri,
178        request_id = %request_id,
179        "Incoming request"
180    );
181
182    let mut meta = ProcessMeta {
183        service: String::new(),
184        region: state.default_region.clone(),
185        account_id: state.default_account_id.clone(),
186        access_key: None,
187    };
188
189    let outcome = process_request(
190        state,
191        &method,
192        &uri,
193        &headers,
194        &body,
195        &request_id,
196        &mut meta,
197    )
198    .await;
199
200    let (status, mut resp_headers, resp_body, operation, error_code) = match outcome {
201        Ok(ProcessOk {
202            status,
203            headers,
204            body,
205            operation,
206        }) => (status, headers, body, Some(operation), None),
207        Err((protocol, error)) => {
208            warn!(
209                error_code = %error.code,
210                error_message = %error.message,
211                request_id = %request_id,
212                "Request failed"
213            );
214            let err_code = error.code.clone();
215            let (status, resp_headers, resp_body) =
216                protocol::serialize_error(protocol, &error, &request_id);
217            (status, resp_headers, resp_body, None, Some(err_code))
218        }
219    };
220    let status_code = status.as_u16();
221    let response_size = resp_body.len() as u64;
222
223    // Capture detail for the inspect drawer before the body is moved into
224    // the response. Bodies are size-capped inside `capture_body`.
225    let body_cap = state.request_details.body_cap();
226    let detail = RequestDetail {
227        id: request_id.clone(),
228        method: method.to_string(),
229        path: uri.path().to_string(),
230        query: uri.query().map(|q| q.to_string()),
231        status_code,
232        request_headers: capture_headers(&headers),
233        response_headers: capture_headers(&resp_headers),
234        request_body: capture_body(&body, body_cap),
235        response_body: capture_body(&resp_body, body_cap),
236    };
237    state.request_details.insert(detail);
238
239    // Extract any X-Awsim-* metadata headers the responding service
240    // attached for the billing meter (e.g. Lambda's per-invocation
241    // memory size, Step Functions' transition count). Pull them off
242    // before draining into the actual HTTP response so they don't
243    // leak to the wire.
244    let memory_mb = resp_headers
245        .remove("x-awsim-memory-mb")
246        .and_then(|v| v.to_str().ok().and_then(|s| s.parse::<u32>().ok()));
247    let state_transitions = resp_headers
248        .remove("x-awsim-state-transitions")
249        .and_then(|v| v.to_str().ok().and_then(|s| s.parse::<u32>().ok()));
250    let character_count = resp_headers
251        .remove("x-awsim-char-count")
252        .and_then(|v| v.to_str().ok().and_then(|s| s.parse::<u64>().ok()));
253
254    let mut builder = Response::builder().status(status);
255    for (key, value) in resp_headers.drain() {
256        if let Some(key) = key {
257            builder = builder.header(key, value);
258        }
259    }
260    let response = builder.body(Body::from(resp_body)).unwrap();
261
262    let duration_ms = started.elapsed().as_secs_f64() * 1000.0;
263    let ts = SystemTime::now()
264        .duration_since(UNIX_EPOCH)
265        .map(|d| d.as_secs_f64())
266        .unwrap_or(0.0);
267    let principal_arn = meta
268        .access_key
269        .as_ref()
270        .map(|ak| format!("arn:aws:iam::{}:access-key/{}", meta.account_id, ak));
271
272    let event = RequestEvent {
273        id: request_id.clone(),
274        ts,
275        method: method.to_string(),
276        path: uri.path().to_string(),
277        service: meta.service,
278        operation,
279        account_id: meta.account_id,
280        region: meta.region,
281        principal_arn,
282        status_code,
283        duration_ms,
284        request_size,
285        response_size,
286        error_code,
287        memory_mb,
288        state_transitions,
289        character_count,
290    };
291    state.events.publish(event);
292
293    (response, request_id)
294}
295
296async fn process_request(
297    state: &AppState,
298    method: &Method,
299    uri: &Uri,
300    headers: &HeaderMap,
301    body: &Bytes,
302    request_id: &str,
303    meta: &mut ProcessMeta,
304) -> Result<ProcessOk, (Protocol, AwsError)> {
305    // 1. Extract service identification from auth header
306    let (service_name, region, account_id, access_key) = extract_service_info(state, headers, uri);
307    meta.service = service_name.clone();
308    meta.region = region.clone();
309    meta.account_id = account_id.clone();
310    meta.access_key = access_key.clone();
311
312    // 2. Find the service handler
313    let handler = state.services.get(&service_name).ok_or_else(|| {
314        let protocol = protocol::detect_protocol(headers, body).unwrap_or(Protocol::RestJson1);
315        (
316            protocol,
317            AwsError::bad_request(
318                "UnknownService",
319                format!("Service '{service_name}' is not registered"),
320            ),
321        )
322    })?;
323
324    let protocol = handler.protocol();
325
326    // 3. Determine effective protocol (use service's declared protocol if detection fails)
327    let detected = protocol::detect_protocol(headers, body).unwrap_or(protocol);
328
329    // 4. Get routes for REST protocols
330    let empty_routes = Vec::new();
331    let routes = state.routes.get(&service_name).unwrap_or(&empty_routes);
332
333    // 5. Parse the request
334    let parsed = protocol::parse_request(detected, method, uri, headers, body, routes)
335        .map_err(|e| (detected, e))?;
336
337    debug!(
338        service = %service_name,
339        operation = %parsed.operation,
340        request_id = %request_id,
341        "Dispatching operation"
342    );
343
344    // 6. Build request context
345    let ctx = crate::router::RequestContext {
346        account_id,
347        region,
348        service: service_name.clone(),
349        access_key,
350        request_id: request_id.to_string(),
351        method: method.to_string(),
352        uri: uri.to_string(),
353        event_bus: Some(state.event_bus.clone()),
354    };
355
356    // 6b. IAM authorization (opt-in via AWSIM_IAM_ENFORCE)
357    if let (Some(action), Some(resource)) = (
358        handler.iam_action(&parsed.operation),
359        handler.iam_resource(&parsed.operation, &parsed.input, &ctx),
360    ) {
361        state
362            .authz
363            .check(&ctx, &action, &resource)
364            .map_err(|e| (detected, e))?;
365    } else {
366        debug!(
367            service = %service_name,
368            operation = %parsed.operation,
369            "Skipping IAM check — handler does not declare action/resource"
370        );
371    }
372
373    let operation = parsed.operation.clone();
374
375    // 6c. Chaos injection — sleep + optionally short-circuit with a
376    // synthetic AWS error before the handler runs. Empty engine is a
377    // no-op fast path so the cost is negligible when chaos is off.
378    if let Some(outcome) = state.chaos.evaluate(&service_name, Some(&operation)) {
379        if let Some(delay) = outcome.latency {
380            tokio::time::sleep(delay).await;
381        }
382        state
383            .chaos
384            .record_injection(&outcome.rule_id, &service_name, Some(&operation));
385        if let Some(err) = outcome.error {
386            let status = axum::http::StatusCode::from_u16(err.status)
387                .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);
388            let error_type = if status.is_server_error() {
389                crate::error::ErrorType::Receiver
390            } else {
391                crate::error::ErrorType::Sender
392            };
393            let aws_err = AwsError {
394                status,
395                code: err.code,
396                message: err.message,
397                error_type,
398                extras: None,
399            };
400            return Err((detected, aws_err));
401        }
402    }
403
404    // 7. Dispatch to service handler
405    let result = handler
406        .handle(&parsed.operation, parsed.input, &ctx)
407        .await
408        .map_err(|e| (detected, e))?;
409
410    // 8. Serialize response using the *detected* protocol so that the wire
411    // format matches what the client expects.  A client that sends an
412    // awsQuery (form-encoded) request expects an XML response, even if the
413    // service declares AwsJson as its primary protocol.
414    let (status, headers, body) =
415        protocol::serialize_response(detected, &parsed.operation, &result, request_id);
416    Ok(ProcessOk {
417        status,
418        headers,
419        body,
420        operation,
421    })
422}
423
424/// Extract service name, region, account ID, and access key from the request.
425fn extract_service_info(
426    state: &AppState,
427    headers: &HeaderMap,
428    uri: &Uri,
429) -> (String, String, String, Option<String>) {
430    // Try Authorization header first
431    if let Some(auth_header) = headers.get("authorization").and_then(|v| v.to_str().ok())
432        && let Some(creds) = auth::parse_authorization(auth_header)
433    {
434        return (
435            creds.service,
436            creds.region,
437            state.default_account_id.clone(),
438            Some(creds.access_key),
439        );
440    }
441
442    // Try X-Amz-Target header (for awsJson services)
443    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok())
444        && let Some(service) = resolve_service_from_target(target)
445    {
446        return (
447            service,
448            state.default_region.clone(),
449            state.default_account_id.clone(),
450            None,
451        );
452    }
453
454    // Try Host header
455    if let Some(host) = headers.get("host").and_then(|v| v.to_str().ok())
456        && let Some(service) = extract_service_from_host(host)
457    {
458        return (
459            service,
460            state.default_region.clone(),
461            state.default_account_id.clone(),
462            None,
463        );
464    }
465
466    // Check for pre-signed URL query parameters (SigV4 in query string)
467    if let Some(query) = uri.query()
468        && query.contains("X-Amz-Credential")
469        && let Some(cred_start) = query.find("X-Amz-Credential=")
470    {
471        let cred_val = &query[cred_start + 17..];
472        let cred_end = cred_val.find('&').unwrap_or(cred_val.len());
473        let cred = &cred_val[..cred_end];
474        let cred_decoded = cred.replace("%2F", "/");
475        let parts: Vec<&str> = cred_decoded.split('/').collect();
476        if parts.len() >= 4 {
477            return (
478                parts[3].to_string(),
479                parts[2].to_string(),
480                state.default_account_id.clone(),
481                Some(parts[0].to_string()),
482            );
483        }
484    }
485
486    // Try path-based detection as last resort (for REST services called without auth)
487    let path = uri.path();
488    if let Some(service) = resolve_service_from_path(path) {
489        return (
490            service,
491            state.default_region.clone(),
492            state.default_account_id.clone(),
493            None,
494        );
495    }
496
497    // Fallback: log what we received so we can diagnose routing failures
498    warn!(
499        auth = ?headers.get("authorization").map(|v| v.to_str().unwrap_or("<non-utf8>")),
500        target = ?headers.get("x-amz-target").map(|v| v.to_str().unwrap_or("<non-utf8>")),
501        host = ?headers.get("host").map(|v| v.to_str().unwrap_or("<non-utf8>")),
502        path = %path,
503        "Could not determine service — falling back to 'unknown'"
504    );
505    (
506        "unknown".to_string(),
507        state.default_region.clone(),
508        state.default_account_id.clone(),
509        None,
510    )
511}
512
/// Translate an `X-Amz-Target` value into a service signing name.
///
/// Only the part of `target` before the first `.` is examined. It is compared
/// against the table below with `starts_with`, top to bottom, and the first
/// matching row wins. Returns `None` when no row matches.
fn resolve_service_from_target(target: &str) -> Option<String> {
    /// `(target prefix, signing name)` pairs, checked in order.
    const PREFIX_TO_SERVICE: &[(&str, &str)] = &[
        // Core services
        ("DynamoDB", "dynamodb"),
        ("AmazonSQS", "sqs"),
        ("AmazonSNS", "sns"),
        ("TrentService", "kms"),
        ("secretsmanager", "secretsmanager"),
        ("AmazonSSM", "ssm"),
        ("Logs", "logs"),
        ("Kinesis", "kinesis"),
        ("AWSStepFunctions", "states"),
        ("AWSEvents", "events"),
        // Auth
        ("AWSCognitoIdentityProviderService", "cognito-idp"),
        ("AWSCognitoIdentityService", "cognito-identity"),
        // Containers
        ("AmazonEC2ContainerServiceV2", "ecs"),
        ("AmazonEC2ContainerRegistry", "ecr"),
        // Data/Analytics
        ("AmazonAthena", "athena"),
        ("AWSGlue", "glue"),
        // Security
        ("CertificateManager", "acm"),
        ("AWSWAF", "wafv2"),
        ("Comprehend", "comprehend"),
        ("kendra", "kendra"),
        // Management & audit
        ("AWSOrganizationsV", "organizations"),
        ("CloudTrail_", "cloudtrail"),
        // Streaming
        ("Firehose_", "firehose"),
        // Cross-service tagging
        ("ResourceGroupsTaggingAPI", "tagging"),
        // Auto scaling
        ("AnyScaleFrontendService", "application-autoscaling"),
        // Cloud Map (Service Discovery)
        ("Route53AutoNaming_v", "servicediscovery"),
        // MemoryDB
        ("AmazonMemoryDB", "memorydb"),
    ];

    let head = target.split('.').next()?;
    PREFIX_TO_SERVICE
        .iter()
        .find(|&&(prefix, _)| head.starts_with(prefix))
        .map(|&(_, service)| service.to_string())
}
559
/// Extract service name from Host header.
/// e.g., "s3.us-east-1.localhost" → "s3"
/// e.g., "sqs.us-east-1.amazonaws.com" → "sqs"
///
/// Any `:port` suffix is ignored. The first dot-separated label is returned
/// when it either contains no `-` or is one of the known service names
/// below; a hyphenated label that is not a known service is assumed to be
/// an S3 virtual-hosted bucket name and yields `None`.
///
/// Returns `None` as well when the host has no `.`, when its first label is
/// empty, or when it is an IP address literal (e.g. `127.0.0.1:4566`) — an
/// address carries no service information, and reporting its first octet as
/// a service would stop callers from trying other detection strategies.
fn extract_service_from_host(host: &str) -> Option<String> {
    /// Service names accepted even though some of them contain a hyphen.
    const KNOWN_SERVICES: &[&str] = &[
        "s3",
        "sqs",
        "sns",
        "dynamodb",
        "lambda",
        "iam",
        "sts",
        "kms",
        "logs",
        "events",
        "states",
        "ssm",
        "secretsmanager",
        "execute-api",
        "cognito-idp",
        "cognito-identity",
        "tagging",
    ];

    // Remove port
    let hostname = host.split(':').next().unwrap_or(host);

    // An IP literal is an address, not `<service>.<region>.<domain>`.
    if hostname.parse::<std::net::IpAddr>().is_ok() {
        return None;
    }

    // A service prefix requires at least two labels.
    let (first, _rest) = hostname.split_once('.')?;
    if first.is_empty() {
        return None;
    }

    // Skip if it looks like a bucket name (for S3 virtual-hosted style)
    if !first.contains('-') || KNOWN_SERVICES.contains(&first) {
        Some(first.to_string())
    } else {
        None
    }
}
597
/// Fallback service detection based purely on the shape of the URI path.
///
/// Used for REST-protocol services when the request carries no auth header
/// (e.g. requests from the admin console). `path` is compared against the
/// prefix table below, top to bottom; `/tags` additionally matches EKS as an
/// exact path. Returns `None` when nothing matches — S3 is deliberately not
/// listed because its paths would match everything.
fn resolve_service_from_path(path: &str) -> Option<String> {
    /// `(path prefix, signing name)` pairs, checked in order.
    const PREFIX_TO_SERVICE: &[(&str, &str)] = &[
        // Lambda
        ("/2015-03-31/functions", "lambda"),
        ("/2018-10-31/layers", "lambda"),
        // API Gateway v2
        ("/v2/apis", "execute-api"),
        // SES v2
        ("/v2/email", "ses"),
        // Route53
        ("/2013-04-01/hostedzone", "route53"),
        ("/2013-04-01/healthcheck", "route53"),
        ("/2013-04-01/tags", "route53"),
        // CloudFront
        ("/2020-05-31/distribution", "cloudfront"),
        ("/2020-05-31/origin-access-control", "cloudfront"),
        ("/2020-05-31/cache-policy", "cloudfront"),
        ("/2020-05-31/tagging", "cloudfront"),
        // AppSync
        ("/v1/apis", "appsync"),
        // Bedrock
        ("/foundation-models", "bedrock"),
        ("/guardrails", "bedrock"),
        ("/model-customization", "bedrock"),
        // Bedrock Runtime
        ("/model/", "bedrock-runtime"),
        // EventBridge Scheduler
        ("/schedules", "scheduler"),
        ("/schedule-groups", "scheduler"),
        // EKS (the bare "/tags" path is handled separately below)
        ("/clusters", "eks"),
        ("/tags/", "eks"),
    ];

    let by_prefix = PREFIX_TO_SERVICE
        .iter()
        .find(|&&(prefix, _)| path.starts_with(prefix))
        .map(|&(_, service)| service);

    // "/tags" counts as EKS only as an exact path; no prefix in the table
    // matches it, so checking it after the table keeps the outcome the same.
    let service = match by_prefix {
        Some(service) => service,
        None if path == "/tags" => "eks",
        None => return None,
    };
    Some(service.to_string())
}
647
#[cfg(test)]
mod browser_probe_tests {
    use super::*;

    /// GET requests for the listed paths are classified as probes.
    #[test]
    fn matches_known_probes() {
        let probes = [
            "/favicon.ico",
            "/.well-known/appspecific/com.chrome.devtools.json",
            "/robots.txt",
            "/apple-touch-icon.png",
        ];
        for path in probes {
            assert!(is_browser_probe(&Method::GET, path));
        }
    }

    /// A probe path requested with a non-GET method is not a probe.
    #[test]
    fn ignores_non_get_methods() {
        // S3 PutObject to /favicon.ico would be a real (if weird) call.
        for method in [Method::PUT, Method::POST] {
            assert!(!is_browser_probe(&method, "/favicon.ico"));
        }
    }

    /// GET requests for any other path go through the normal pipeline.
    #[test]
    fn ignores_unknown_paths() {
        for path in ["/", "/some-bucket/key", "/_awsim/stats"] {
            assert!(!is_browser_probe(&Method::GET, path));
        }
    }
}