Skip to main content

awsim_core/
gateway.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::{Instant, SystemTime, UNIX_EPOCH};
5
6use axum::body::Body;
7use axum::extract::State;
8use axum::http::{HeaderMap, Method, Response, StatusCode, Uri};
9use bytes::Bytes;
10use tracing::{debug, info, warn};
11
12use crate::ServiceHandler;
13use crate::auth;
14use crate::authz::AuthzEngine;
15use crate::body_store::BodyStore;
16use crate::error::AwsError;
17use crate::events::EventBus;
18use crate::protocol::{self, Protocol, RouteDefinition};
19use crate::request_detail::{RequestDetail, RequestDetailStore, capture_body, capture_headers};
20use crate::request_event::{RequestEvent, RequestEventBus};
21
/// Pairs a service name with a shared [`BodyStore`].
///
/// `AppState::body_stores` holds a list of these handles.
#[derive(Clone)]
pub struct BodyStoreHandle {
    /// Name of the service this handle belongs to.
    pub service_name: String,
    /// Group names associated with this store.
    /// NOTE(review): what a "group" denotes is not visible in this file — confirm against `BodyStore`.
    pub groups: Vec<String>,
    /// Shared reference to the store itself.
    pub body_store: Arc<BodyStore>,
}
28
/// Shared application state passed to all request handlers.
///
/// The type derives `Clone`. The service and route maps, the request counter,
/// the authorization engine and the body-store list are wrapped in `Arc`, so
/// clones share those values rather than copying them.
#[derive(Clone)]
pub struct AppState {
    /// Registered service handlers, keyed by signing name.
    pub services: Arc<HashMap<String, Arc<dyn ServiceHandler>>>,
    /// Route definitions for REST-protocol services, keyed by signing name.
    pub routes: Arc<HashMap<String, Vec<RouteDefinition>>>,
    /// Default AWS region.
    pub default_region: String,
    /// Default AWS account ID.
    pub default_account_id: String,
    /// Internal event bus for cross-service fan-out (SNS→SQS, etc.).
    pub event_bus: EventBus,
    /// Total number of AWS API requests handled since startup.
    /// Incremented once per `dispatch_request` call.
    pub request_count: Arc<AtomicU64>,
    /// Server startup time.
    pub start_time: std::time::Instant,
    /// IAM authorization engine — opt-in via AWSIM_IAM_ENFORCE=true.
    pub authz: Arc<AuthzEngine>,
    /// Per-service `BodyStore` handles, populated when persistence is enabled.
    pub body_stores: Arc<Vec<BodyStoreHandle>>,
    /// Persistence root directory, when persistence is enabled.
    pub data_dir: Option<Arc<std::path::PathBuf>>,
    /// Broadcast bus for per-request observability events (consumed by SSE).
    pub events: RequestEventBus,
    /// Ring buffer of recent per-request detail captures (headers + bodies).
    pub request_details: RequestDetailStore,
}
57
58impl AppState {
59    pub fn new(default_region: String, default_account_id: String) -> Self {
60        Self {
61            services: Arc::new(HashMap::new()),
62            routes: Arc::new(HashMap::new()),
63            default_region,
64            default_account_id,
65            event_bus: EventBus::new(),
66            request_count: Arc::new(AtomicU64::new(0)),
67            start_time: std::time::Instant::now(),
68            authz: Arc::new(AuthzEngine::from_env()),
69            body_stores: Arc::new(Vec::new()),
70            data_dir: None,
71            events: RequestEventBus::new(),
72            request_details: RequestDetailStore::default(),
73        }
74    }
75
76    /// Register a service handler.
77    pub fn register(&mut self, handler: Arc<dyn ServiceHandler>, routes: Vec<RouteDefinition>) {
78        let signing_name = handler.signing_name().to_string();
79        let service_name = handler.service_name().to_string();
80
81        info!(
82            service = %service_name,
83            signing_name = %signing_name,
84            protocol = ?handler.protocol(),
85            routes = routes.len(),
86            "Registered service"
87        );
88
89        Arc::get_mut(&mut self.services)
90            .unwrap()
91            .insert(signing_name.clone(), handler);
92
93        if !routes.is_empty() {
94            Arc::get_mut(&mut self.routes)
95                .unwrap()
96                .insert(signing_name, routes);
97        }
98    }
99}
100
/// Successful result of `process_request`: the serialized response together
/// with the name of the operation that produced it.
struct ProcessOk {
    /// HTTP status produced by `protocol::serialize_response`.
    status: StatusCode,
    /// Response headers produced by `protocol::serialize_response`.
    headers: HeaderMap,
    /// Serialized response body.
    body: Bytes,
    /// Operation name taken from the parsed request.
    operation: String,
}
107
/// Request attribution collected while a request is processed.
///
/// `dispatch_request` creates it with defaults and passes it to
/// `process_request` by mutable reference, so the values are still available
/// for the request event when processing returns an error.
struct ProcessMeta {
    /// Service name resolved by `extract_service_info`; empty until then.
    service: String,
    /// Region resolved for the request; starts as the configured default.
    region: String,
    /// Account id; starts as the configured default.
    account_id: String,
    /// Access key found in the request's credentials, if any.
    access_key: Option<String>,
}
114
115/// Main request handler — all AWS API requests funnel through here.
116pub async fn handle_request(
117    State(state): State<AppState>,
118    method: Method,
119    uri: Uri,
120    headers: HeaderMap,
121    body: Bytes,
122) -> Response<Body> {
123    // Short-circuit common browser probes (favicon, devtools well-known
124    // path) before the AWS dispatch pipeline runs. They are not API
125    // calls — silently 204'ing keeps them out of the request log and
126    // out of the inspect drawer.
127    if is_browser_probe(&method, uri.path()) {
128        return Response::builder()
129            .status(StatusCode::NO_CONTENT)
130            .body(Body::empty())
131            .unwrap();
132    }
133    let (response, _id) = dispatch_request(&state, method, uri, headers, body).await;
134    response
135}
136
137/// Recognise the handful of unsolicited paths browsers hit when you point
138/// them at AWSim's port — they're not AWS requests and shouldn't appear
139/// in logs or stats. Conservative on purpose: only paths we've actually
140/// seen in real traces are listed.
141fn is_browser_probe(method: &Method, path: &str) -> bool {
142    if method != Method::GET {
143        return false;
144    }
145    matches!(
146        path,
147        "/favicon.ico"
148            | "/apple-touch-icon.png"
149            | "/apple-touch-icon-precomposed.png"
150            | "/robots.txt"
151            | "/.well-known/appspecific/com.chrome.devtools.json"
152    )
153}
154
155/// Same as `handle_request`, but takes the state by reference and returns
156/// the generated request id alongside the response. Lets internal callers
157/// (replay, etc.) drive the gateway pipeline without going through axum.
158pub async fn dispatch_request(
159    state: &AppState,
160    method: Method,
161    uri: Uri,
162    headers: HeaderMap,
163    body: Bytes,
164) -> (Response<Body>, String) {
165    state.request_count.fetch_add(1, Ordering::Relaxed);
166
167    let request_id = uuid::Uuid::new_v4().to_string();
168    let started = Instant::now();
169    let request_size = body.len() as u64;
170
171    debug!(
172        method = %method,
173        uri = %uri,
174        request_id = %request_id,
175        "Incoming request"
176    );
177
178    let mut meta = ProcessMeta {
179        service: String::new(),
180        region: state.default_region.clone(),
181        account_id: state.default_account_id.clone(),
182        access_key: None,
183    };
184
185    let outcome = process_request(
186        state,
187        &method,
188        &uri,
189        &headers,
190        &body,
191        &request_id,
192        &mut meta,
193    )
194    .await;
195
196    let (status, resp_headers, resp_body, operation, error_code) = match outcome {
197        Ok(ProcessOk {
198            status,
199            headers,
200            body,
201            operation,
202        }) => (status, headers, body, Some(operation), None),
203        Err((protocol, error)) => {
204            warn!(
205                error_code = %error.code,
206                error_message = %error.message,
207                request_id = %request_id,
208                "Request failed"
209            );
210            let err_code = error.code.clone();
211            let (status, resp_headers, resp_body) =
212                protocol::serialize_error(protocol, &error, &request_id);
213            (status, resp_headers, resp_body, None, Some(err_code))
214        }
215    };
216    let status_code = status.as_u16();
217    let response_size = resp_body.len() as u64;
218
219    // Capture detail for the inspect drawer before the body is moved into
220    // the response. Bodies are size-capped inside `capture_body`.
221    let body_cap = state.request_details.body_cap();
222    let detail = RequestDetail {
223        id: request_id.clone(),
224        method: method.to_string(),
225        path: uri.path().to_string(),
226        query: uri.query().map(|q| q.to_string()),
227        status_code,
228        request_headers: capture_headers(&headers),
229        response_headers: capture_headers(&resp_headers),
230        request_body: capture_body(&body, body_cap),
231        response_body: capture_body(&resp_body, body_cap),
232    };
233    state.request_details.insert(detail);
234
235    let mut builder = Response::builder().status(status);
236    let mut resp_headers = resp_headers;
237    for (key, value) in resp_headers.drain() {
238        if let Some(key) = key {
239            builder = builder.header(key, value);
240        }
241    }
242    let response = builder.body(Body::from(resp_body)).unwrap();
243
244    let duration_ms = started.elapsed().as_secs_f64() * 1000.0;
245    let ts = SystemTime::now()
246        .duration_since(UNIX_EPOCH)
247        .map(|d| d.as_secs_f64())
248        .unwrap_or(0.0);
249    let principal_arn = meta
250        .access_key
251        .as_ref()
252        .map(|ak| format!("arn:aws:iam::{}:access-key/{}", meta.account_id, ak));
253
254    let event = RequestEvent {
255        id: request_id.clone(),
256        ts,
257        method: method.to_string(),
258        path: uri.path().to_string(),
259        service: meta.service,
260        operation,
261        account_id: meta.account_id,
262        region: meta.region,
263        principal_arn,
264        status_code,
265        duration_ms,
266        request_size,
267        response_size,
268        error_code,
269    };
270    state.events.publish(event);
271
272    (response, request_id)
273}
274
275async fn process_request(
276    state: &AppState,
277    method: &Method,
278    uri: &Uri,
279    headers: &HeaderMap,
280    body: &Bytes,
281    request_id: &str,
282    meta: &mut ProcessMeta,
283) -> Result<ProcessOk, (Protocol, AwsError)> {
284    // 1. Extract service identification from auth header
285    let (service_name, region, account_id, access_key) = extract_service_info(state, headers, uri);
286    meta.service = service_name.clone();
287    meta.region = region.clone();
288    meta.account_id = account_id.clone();
289    meta.access_key = access_key.clone();
290
291    // 2. Find the service handler
292    let handler = state.services.get(&service_name).ok_or_else(|| {
293        let protocol = protocol::detect_protocol(headers, body).unwrap_or(Protocol::RestJson1);
294        (
295            protocol,
296            AwsError::bad_request(
297                "UnknownService",
298                format!("Service '{service_name}' is not registered"),
299            ),
300        )
301    })?;
302
303    let protocol = handler.protocol();
304
305    // 3. Determine effective protocol (use service's declared protocol if detection fails)
306    let detected = protocol::detect_protocol(headers, body).unwrap_or(protocol);
307
308    // 4. Get routes for REST protocols
309    let empty_routes = Vec::new();
310    let routes = state.routes.get(&service_name).unwrap_or(&empty_routes);
311
312    // 5. Parse the request
313    let parsed = protocol::parse_request(detected, method, uri, headers, body, routes)
314        .map_err(|e| (detected, e))?;
315
316    debug!(
317        service = %service_name,
318        operation = %parsed.operation,
319        request_id = %request_id,
320        "Dispatching operation"
321    );
322
323    // 6. Build request context
324    let ctx = crate::router::RequestContext {
325        account_id,
326        region,
327        service: service_name.clone(),
328        access_key,
329        request_id: request_id.to_string(),
330        method: method.to_string(),
331        uri: uri.to_string(),
332        event_bus: Some(state.event_bus.clone()),
333    };
334
335    // 6b. IAM authorization (opt-in via AWSIM_IAM_ENFORCE)
336    if let (Some(action), Some(resource)) = (
337        handler.iam_action(&parsed.operation),
338        handler.iam_resource(&parsed.operation, &parsed.input, &ctx),
339    ) {
340        state
341            .authz
342            .check(&ctx, &action, &resource)
343            .map_err(|e| (detected, e))?;
344    } else {
345        debug!(
346            service = %service_name,
347            operation = %parsed.operation,
348            "Skipping IAM check — handler does not declare action/resource"
349        );
350    }
351
352    let operation = parsed.operation.clone();
353
354    // 7. Dispatch to service handler
355    let result = handler
356        .handle(&parsed.operation, parsed.input, &ctx)
357        .await
358        .map_err(|e| (detected, e))?;
359
360    // 8. Serialize response using the *detected* protocol so that the wire
361    // format matches what the client expects.  A client that sends an
362    // awsQuery (form-encoded) request expects an XML response, even if the
363    // service declares AwsJson as its primary protocol.
364    let (status, headers, body) =
365        protocol::serialize_response(detected, &parsed.operation, &result, request_id);
366    Ok(ProcessOk {
367        status,
368        headers,
369        body,
370        operation,
371    })
372}
373
374/// Extract service name, region, account ID, and access key from the request.
375fn extract_service_info(
376    state: &AppState,
377    headers: &HeaderMap,
378    uri: &Uri,
379) -> (String, String, String, Option<String>) {
380    // Try Authorization header first
381    if let Some(auth_header) = headers.get("authorization").and_then(|v| v.to_str().ok())
382        && let Some(creds) = auth::parse_authorization(auth_header)
383    {
384        return (
385            creds.service,
386            creds.region,
387            state.default_account_id.clone(),
388            Some(creds.access_key),
389        );
390    }
391
392    // Try X-Amz-Target header (for awsJson services)
393    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok())
394        && let Some(service) = resolve_service_from_target(target)
395    {
396        return (
397            service,
398            state.default_region.clone(),
399            state.default_account_id.clone(),
400            None,
401        );
402    }
403
404    // Try Host header
405    if let Some(host) = headers.get("host").and_then(|v| v.to_str().ok())
406        && let Some(service) = extract_service_from_host(host)
407    {
408        return (
409            service,
410            state.default_region.clone(),
411            state.default_account_id.clone(),
412            None,
413        );
414    }
415
416    // Check for pre-signed URL query parameters (SigV4 in query string)
417    if let Some(query) = uri.query()
418        && query.contains("X-Amz-Credential")
419        && let Some(cred_start) = query.find("X-Amz-Credential=")
420    {
421        let cred_val = &query[cred_start + 17..];
422        let cred_end = cred_val.find('&').unwrap_or(cred_val.len());
423        let cred = &cred_val[..cred_end];
424        let cred_decoded = cred.replace("%2F", "/");
425        let parts: Vec<&str> = cred_decoded.split('/').collect();
426        if parts.len() >= 4 {
427            return (
428                parts[3].to_string(),
429                parts[2].to_string(),
430                state.default_account_id.clone(),
431                Some(parts[0].to_string()),
432            );
433        }
434    }
435
436    // Try path-based detection as last resort (for REST services called without auth)
437    let path = uri.path();
438    if let Some(service) = resolve_service_from_path(path) {
439        return (
440            service,
441            state.default_region.clone(),
442            state.default_account_id.clone(),
443            None,
444        );
445    }
446
447    // Fallback: log what we received so we can diagnose routing failures
448    warn!(
449        auth = ?headers.get("authorization").map(|v| v.to_str().unwrap_or("<non-utf8>")),
450        target = ?headers.get("x-amz-target").map(|v| v.to_str().unwrap_or("<non-utf8>")),
451        host = ?headers.get("host").map(|v| v.to_str().unwrap_or("<non-utf8>")),
452        path = %path,
453        "Could not determine service — falling back to 'unknown'"
454    );
455    (
456        "unknown".to_string(),
457        state.default_region.clone(),
458        state.default_account_id.clone(),
459        None,
460    )
461}
462
/// Map X-Amz-Target prefixes to service signing names.
///
/// A target has the form `<prefix>.<Operation>`. Most prefixes are a single
/// token (`DynamoDB_20120810`), but some are themselves dotted (CloudTrail uses
/// `com.amazonaws.cloudtrail.v20131101.CloudTrail_20131101`). Every
/// dot-separated segment of the prefix is therefore checked, left to right; the
/// trailing operation name is never inspected. A target without any dot is
/// treated as a bare prefix.
///
/// Returns `None` when no segment starts with a known prefix.
fn resolve_service_from_target(target: &str) -> Option<String> {
    /// `(target prefix, signing name)` pairs, checked in order.
    const TARGET_PREFIXES: &[(&str, &str)] = &[
        // Core services
        ("DynamoDB", "dynamodb"),
        ("AmazonSQS", "sqs"),
        ("AmazonSNS", "sns"),
        ("TrentService", "kms"),
        ("secretsmanager", "secretsmanager"),
        ("AmazonSSM", "ssm"),
        ("Logs", "logs"),
        ("Kinesis", "kinesis"),
        ("AWSStepFunctions", "states"),
        ("AWSEvents", "events"),
        // Auth
        ("AWSCognitoIdentityProviderService", "cognito-idp"),
        ("AWSCognitoIdentityService", "cognito-identity"),
        // Containers
        ("AmazonEC2ContainerServiceV2", "ecs"),
        ("AmazonEC2ContainerRegistry", "ecr"),
        // Data/Analytics
        ("AmazonAthena", "athena"),
        ("AWSGlue", "glue"),
        // Security
        ("CertificateManager", "acm"),
        ("AWSWAF", "wafv2"),
        ("Comprehend", "comprehend"),
        ("kendra", "kendra"),
        // NOTE(review): believed to be the prefix the AWS SDKs send for Kendra —
        // confirm against the Kendra service model.
        ("AWSKendraFrontendService", "kendra"),
        // Management & audit
        ("AWSOrganizationsV", "organizations"),
        ("CloudTrail_", "cloudtrail"),
        // Streaming
        ("Firehose_", "firehose"),
        // Cross-service tagging
        ("ResourceGroupsTaggingAPI", "tagging"),
        // Auto scaling
        ("AnyScaleFrontendService", "application-autoscaling"),
        // Cloud Map (Service Discovery)
        ("Route53AutoNaming_v", "servicediscovery"),
        // MemoryDB
        ("AmazonMemoryDB", "memorydb"),
    ];

    // Everything before the last dot is the prefix; the rest is the operation.
    let prefix = match target.rsplit_once('.') {
        Some((prefix, _operation)) => prefix,
        None => target,
    };

    prefix.split('.').find_map(|segment| {
        TARGET_PREFIXES
            .iter()
            .find(|(known, _)| segment.starts_with(*known))
            .map(|(_, service)| (*service).to_string())
    })
}
509
/// Extract service name from Host header.
/// e.g., "s3.us-east-1.localhost" → "s3"
/// e.g., "sqs.us-east-1.amazonaws.com" → "sqs"
///
/// The first dot-separated label is taken as the service name unless it looks
/// like an S3 virtual-hosted bucket name: a label containing `-` is only
/// accepted when it is one of the known service names.
///
/// Returns `None` when the host (without port) has no dot, is an IPv4 literal
/// such as `127.0.0.1`, or starts with an empty label.
fn extract_service_from_host(host: &str) -> Option<String> {
    /// Service names accepted even though they might contain a hyphen.
    const KNOWN_SERVICES: [&str; 17] = [
        "s3",
        "sqs",
        "sns",
        "dynamodb",
        "lambda",
        "iam",
        "sts",
        "kms",
        "logs",
        "events",
        "states",
        "ssm",
        "secretsmanager",
        "execute-api",
        "cognito-idp",
        "cognito-identity",
        "tagging",
    ];

    // Remove port
    let host = host.split(':').next().unwrap_or(host);

    // An IPv4 address is dotted but carries no service label. Treating its
    // first octet as a service would stop the caller from trying the remaining
    // detection strategies.
    if host.parse::<std::net::Ipv4Addr>().is_ok() {
        return None;
    }

    // No dot means no sub-domain label to interpret.
    let (first, _rest) = host.split_once('.')?;
    if first.is_empty() {
        return None;
    }

    // Skip if it looks like a bucket name (for S3 virtual-hosted style)
    let looks_like_service = !first.contains('-') || KNOWN_SERVICES.contains(&first);
    looks_like_service.then(|| first.to_string())
}
547
/// Last-resort service detection based on the shape of the URI path.
///
/// Used for REST-protocol services when the request carries no auth header
/// (e.g., requests from the admin console). Returns the signing name of the
/// first matching service, or `None` when the path is not recognised.
///
/// S3 is intentionally absent: it has no distinctive prefix, so a rule for it
/// would swallow every otherwise-unmatched path.
fn resolve_service_from_path(path: &str) -> Option<String> {
    /// `(path prefix, service)` pairs, checked in order.
    const PREFIX_TABLE: &[(&str, &str)] = &[
        // Lambda
        ("/2015-03-31/functions", "lambda"),
        ("/2018-10-31/layers", "lambda"),
        // API Gateway v2
        ("/v2/apis", "execute-api"),
        // SES v2
        ("/v2/email", "ses"),
        // Route53
        ("/2013-04-01/hostedzone", "route53"),
        ("/2013-04-01/healthcheck", "route53"),
        ("/2013-04-01/tags", "route53"),
        // CloudFront
        ("/2020-05-31/distribution", "cloudfront"),
        ("/2020-05-31/origin-access-control", "cloudfront"),
        ("/2020-05-31/cache-policy", "cloudfront"),
        ("/2020-05-31/tagging", "cloudfront"),
        // AppSync
        ("/v1/apis", "appsync"),
        // Bedrock
        ("/foundation-models", "bedrock"),
        ("/guardrails", "bedrock"),
        ("/model-customization", "bedrock"),
        // Bedrock Runtime
        ("/model/", "bedrock-runtime"),
        // EventBridge Scheduler
        ("/schedules", "scheduler"),
        ("/schedule-groups", "scheduler"),
        // EKS
        ("/clusters", "eks"),
        ("/tags/", "eks"),
    ];

    let by_prefix = PREFIX_TABLE
        .iter()
        .find(|(prefix, _)| path.starts_with(*prefix))
        .map(|(_, service)| *service);

    // EKS also owns the bare `/tags` path, which must match exactly rather
    // than as a prefix.
    let service = match by_prefix {
        Some(service) => service,
        None if path == "/tags" => "eks",
        None => return None,
    };
    Some(service.to_string())
}
597
/// Unit tests for `is_browser_probe`.
#[cfg(test)]
mod browser_probe_tests {
    use super::*;

    /// Listed probe paths are recognised when fetched with GET.
    #[test]
    fn matches_known_probes() {
        let probes = [
            "/favicon.ico",
            "/.well-known/appspecific/com.chrome.devtools.json",
            "/robots.txt",
            "/apple-touch-icon.png",
        ];
        for path in probes {
            assert!(is_browser_probe(&Method::GET, path));
        }
    }

    /// A probe path with any other verb is left alone.
    #[test]
    fn ignores_non_get_methods() {
        // S3 PutObject to /favicon.ico would be a real (if weird) call.
        for method in [Method::PUT, Method::POST] {
            assert!(!is_browser_probe(&method, "/favicon.ico"));
        }
    }

    /// Paths that are not on the list are never treated as probes.
    #[test]
    fn ignores_unknown_paths() {
        for path in ["/", "/some-bucket/key", "/_awsim/stats"] {
            assert!(!is_browser_probe(&Method::GET, path));
        }
    }
}