Skip to main content

awsim_core/
gateway.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use axum::body::Body;
6use axum::extract::State;
7use axum::http::{HeaderMap, Method, Response, StatusCode, Uri};
8use bytes::Bytes;
9use tracing::{debug, info, warn};
10
11use crate::auth;
12use crate::authz::AuthzEngine;
13use crate::error::AwsError;
14use crate::events::EventBus;
15use crate::protocol::{self, Protocol, RouteDefinition};
16use crate::ServiceHandler;
17
18/// Shared application state passed to all request handlers.
19#[derive(Clone)]
20pub struct AppState {
21    /// Registered service handlers, keyed by signing name.
22    pub services: Arc<HashMap<String, Arc<dyn ServiceHandler>>>,
23    /// Route definitions for REST-protocol services, keyed by signing name.
24    pub routes: Arc<HashMap<String, Vec<RouteDefinition>>>,
25    /// Default AWS region.
26    pub default_region: String,
27    /// Default AWS account ID.
28    pub default_account_id: String,
29    /// Internal event bus for cross-service fan-out (SNS→SQS, etc.).
30    pub event_bus: EventBus,
31    /// Total number of AWS API requests handled since startup.
32    pub request_count: Arc<AtomicU64>,
33    /// Server startup time.
34    pub start_time: std::time::Instant,
35    /// IAM authorization engine — opt-in via AWSIM_IAM_ENFORCE=true.
36    pub authz: Arc<AuthzEngine>,
37}
38
39impl AppState {
40    pub fn new(default_region: String, default_account_id: String) -> Self {
41        Self {
42            services: Arc::new(HashMap::new()),
43            routes: Arc::new(HashMap::new()),
44            default_region,
45            default_account_id,
46            event_bus: EventBus::new(),
47            request_count: Arc::new(AtomicU64::new(0)),
48            start_time: std::time::Instant::now(),
49            authz: Arc::new(AuthzEngine::from_env()),
50        }
51    }
52
53    /// Register a service handler.
54    pub fn register(&mut self, handler: Arc<dyn ServiceHandler>, routes: Vec<RouteDefinition>) {
55        let signing_name = handler.signing_name().to_string();
56        let service_name = handler.service_name().to_string();
57
58        info!(
59            service = %service_name,
60            signing_name = %signing_name,
61            protocol = ?handler.protocol(),
62            routes = routes.len(),
63            "Registered service"
64        );
65
66        Arc::get_mut(&mut self.services)
67            .unwrap()
68            .insert(signing_name.clone(), handler);
69
70        if !routes.is_empty() {
71            Arc::get_mut(&mut self.routes)
72                .unwrap()
73                .insert(signing_name, routes);
74        }
75    }
76}
77
78/// Main request handler — all AWS API requests funnel through here.
79pub async fn handle_request(
80    State(state): State<AppState>,
81    method: Method,
82    uri: Uri,
83    headers: HeaderMap,
84    body: Bytes,
85) -> Response<Body> {
86    state.request_count.fetch_add(1, Ordering::Relaxed);
87
88    let request_id = uuid::Uuid::new_v4().to_string();
89
90    debug!(
91        method = %method,
92        uri = %uri,
93        request_id = %request_id,
94        "Incoming request"
95    );
96
97    match process_request(&state, &method, &uri, &headers, &body, &request_id).await {
98        Ok((status, mut resp_headers, resp_body)) => {
99            let mut builder = Response::builder().status(status);
100            for (key, value) in resp_headers.drain() {
101                if let Some(key) = key {
102                    builder = builder.header(key, value);
103                }
104            }
105            builder.body(Body::from(resp_body)).unwrap()
106        }
107        Err((protocol, error)) => {
108            warn!(
109                error_code = %error.code,
110                error_message = %error.message,
111                request_id = %request_id,
112                "Request failed"
113            );
114            let (status, mut resp_headers, resp_body) =
115                protocol::serialize_error(protocol, &error, &request_id);
116            let mut builder = Response::builder().status(status);
117            for (key, value) in resp_headers.drain() {
118                if let Some(key) = key {
119                    builder = builder.header(key, value);
120                }
121            }
122            builder.body(Body::from(resp_body)).unwrap()
123        }
124    }
125}
126
127async fn process_request(
128    state: &AppState,
129    method: &Method,
130    uri: &Uri,
131    headers: &HeaderMap,
132    body: &Bytes,
133    request_id: &str,
134) -> Result<(StatusCode, HeaderMap, Bytes), (Protocol, AwsError)> {
135    // 1. Extract service identification from auth header
136    let (service_name, region, account_id, access_key) =
137        extract_service_info(state, headers, uri);
138
139    // 2. Find the service handler
140    let handler = state
141        .services
142        .get(&service_name)
143        .ok_or_else(|| {
144            let protocol = protocol::detect_protocol(headers, body).unwrap_or(Protocol::RestJson1);
145            (
146                protocol,
147                AwsError::bad_request(
148                    "UnknownService",
149                    format!("Service '{service_name}' is not registered"),
150                ),
151            )
152        })?;
153
154    let protocol = handler.protocol();
155
156    // 3. Determine effective protocol (use service's declared protocol if detection fails)
157    let detected = protocol::detect_protocol(headers, body).unwrap_or(protocol);
158
159    // 4. Get routes for REST protocols
160    let empty_routes = Vec::new();
161    let routes = state.routes.get(&service_name).unwrap_or(&empty_routes);
162
163    // 5. Parse the request
164    let parsed = protocol::parse_request(detected, method, uri, headers, body, routes)
165        .map_err(|e| (detected, e))?;
166
167    debug!(
168        service = %service_name,
169        operation = %parsed.operation,
170        request_id = %request_id,
171        "Dispatching operation"
172    );
173
174    // 6. Build request context
175    let ctx = crate::router::RequestContext {
176        account_id,
177        region,
178        service: service_name.clone(),
179        access_key,
180        request_id: request_id.to_string(),
181        method: method.to_string(),
182        uri: uri.to_string(),
183        event_bus: Some(state.event_bus.clone()),
184    };
185
186    // 6b. IAM authorization (opt-in via AWSIM_IAM_ENFORCE)
187    if let (Some(action), Some(resource)) = (
188        handler.iam_action(&parsed.operation),
189        handler.iam_resource(&parsed.operation, &parsed.input, &ctx),
190    ) {
191        state
192            .authz
193            .check(&ctx, &action, &resource)
194            .map_err(|e| (detected, e))?;
195    } else {
196        debug!(
197            service = %service_name,
198            operation = %parsed.operation,
199            "Skipping IAM check — handler does not declare action/resource"
200        );
201    }
202
203    // 7. Dispatch to service handler
204    let result = handler
205        .handle(&parsed.operation, parsed.input, &ctx)
206        .await
207        .map_err(|e| (detected, e))?;
208
209    // 8. Serialize response using the *detected* protocol so that the wire
210    // format matches what the client expects.  A client that sends an
211    // awsQuery (form-encoded) request expects an XML response, even if the
212    // service declares AwsJson as its primary protocol.
213    Ok(protocol::serialize_response(
214        detected,
215        &parsed.operation,
216        &result,
217        request_id,
218    ))
219}
220
221/// Extract service name, region, account ID, and access key from the request.
222fn extract_service_info(
223    state: &AppState,
224    headers: &HeaderMap,
225    uri: &Uri,
226) -> (String, String, String, Option<String>) {
227    // Try Authorization header first
228    if let Some(auth_header) = headers.get("authorization").and_then(|v| v.to_str().ok()) {
229        if let Some(creds) = auth::parse_authorization(auth_header) {
230            return (
231                creds.service,
232                creds.region,
233                state.default_account_id.clone(),
234                Some(creds.access_key),
235            );
236        }
237    }
238
239    // Try X-Amz-Target header (for awsJson services)
240    if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
241        if let Some(service) = resolve_service_from_target(target) {
242            return (
243                service,
244                state.default_region.clone(),
245                state.default_account_id.clone(),
246                None,
247            );
248        }
249    }
250
251    // Try Host header
252    if let Some(host) = headers.get("host").and_then(|v| v.to_str().ok()) {
253        if let Some(service) = extract_service_from_host(host) {
254            return (
255                service,
256                state.default_region.clone(),
257                state.default_account_id.clone(),
258                None,
259            );
260        }
261    }
262
263    // Check for pre-signed URL query parameters (SigV4 in query string)
264    if let Some(query) = uri.query() {
265        if query.contains("X-Amz-Credential") {
266            if let Some(cred_start) = query.find("X-Amz-Credential=") {
267                let cred_val = &query[cred_start + 17..];
268                let cred_end = cred_val.find('&').unwrap_or(cred_val.len());
269                let cred = &cred_val[..cred_end];
270                let cred_decoded = cred.replace("%2F", "/");
271                let parts: Vec<&str> = cred_decoded.split('/').collect();
272                if parts.len() >= 4 {
273                    return (
274                        parts[3].to_string(),
275                        parts[2].to_string(),
276                        state.default_account_id.clone(),
277                        Some(parts[0].to_string()),
278                    );
279                }
280            }
281        }
282    }
283
284    // Try path-based detection as last resort (for REST services called without auth)
285    let path = uri.path();
286    if let Some(service) = resolve_service_from_path(path) {
287        return (
288            service,
289            state.default_region.clone(),
290            state.default_account_id.clone(),
291            None,
292        );
293    }
294
295    // Fallback: log what we received so we can diagnose routing failures
296    warn!(
297        auth = ?headers.get("authorization").map(|v| v.to_str().unwrap_or("<non-utf8>")),
298        target = ?headers.get("x-amz-target").map(|v| v.to_str().unwrap_or("<non-utf8>")),
299        host = ?headers.get("host").map(|v| v.to_str().unwrap_or("<non-utf8>")),
300        path = %path,
301        "Could not determine service — falling back to 'unknown'"
302    );
303    (
304        "unknown".to_string(),
305        state.default_region.clone(),
306        state.default_account_id.clone(),
307        None,
308    )
309}
310
/// Map an `X-Amz-Target` header value to the signing name of its service.
///
/// Only the text before the first `.` is examined (for
/// `DynamoDB_20120810.PutItem` that is `DynamoDB_20120810`). It is
/// prefix-matched against the table below, and the first matching row wins.
/// Unrecognised targets yield `None`.
fn resolve_service_from_target(target: &str) -> Option<String> {
    /// `(target prefix, signing name)` rows, checked top to bottom.
    const TARGET_PREFIXES: &[(&str, &str)] = &[
        // Core services
        ("DynamoDB", "dynamodb"),
        ("AmazonSQS", "sqs"),
        ("AmazonSNS", "sns"),
        ("TrentService", "kms"),
        ("secretsmanager", "secretsmanager"),
        ("AmazonSSM", "ssm"),
        ("Logs", "logs"),
        ("Kinesis", "kinesis"),
        ("AWSStepFunctions", "states"),
        ("AWSEvents", "events"),
        // Auth
        ("AWSCognitoIdentityProviderService", "cognito-idp"),
        ("AWSCognitoIdentityService", "cognito-identity"),
        // Containers
        ("AmazonEC2ContainerServiceV2", "ecs"),
        ("AmazonEC2ContainerRegistry", "ecr"),
        // Data/Analytics
        ("AmazonAthena", "athena"),
        ("AWSGlue", "glue"),
        // Security
        ("CertificateManager", "acm"),
        ("AWSWAF", "wafv2"),
        ("Comprehend", "comprehend"),
        ("kendra", "kendra"),
        // Management & audit
        ("AWSOrganizationsV", "organizations"),
        ("CloudTrail_", "cloudtrail"),
        // Streaming
        ("Firehose_", "firehose"),
    ];

    let prefix = match target.split_once('.') {
        Some((head, _)) => head,
        None => target,
    };

    TARGET_PREFIXES
        .iter()
        .find(|(known, _)| prefix.starts_with(*known))
        .map(|&(_, service)| service.to_owned())
}
349
/// Extract the service signing name from the first label of a `Host` header.
///
/// e.g., "s3.us-east-1.localhost" → "s3"
/// e.g., "sqs.us-east-1.amazonaws.com:443" → "sqs"
///
/// Any `:port` suffix is ignored. Returns `None` when:
/// - the host has fewer than two dot-separated labels (e.g. `localhost:4566`);
/// - the first label is empty or purely numeric. IPv4 literals such as
///   `127.0.0.1` and account-ID prefixes are never service names, and
///   accepting them would stop the caller from falling through to its
///   pre-signed-URL and path-based detection;
/// - the first label contains a hyphen and is not a known hyphenated service
///   name. This is the heuristic that skips S3 virtual-hosted bucket names
///   like `my-bucket.s3.localhost`. Hyphen-free bucket names are NOT detected
///   and are returned as if they were service names.
fn extract_service_from_host(host: &str) -> Option<String> {
    const KNOWN_SERVICES: [&str; 16] = [
        "s3",
        "sqs",
        "sns",
        "dynamodb",
        "lambda",
        "iam",
        "sts",
        "kms",
        "logs",
        "events",
        "states",
        "ssm",
        "secretsmanager",
        "execute-api",
        "cognito-idp",
        "cognito-identity",
    ];

    // Remove port
    let hostname = host.split(':').next().unwrap_or(host);
    // `split_once` succeeds exactly when there are at least two labels.
    let (first, _) = hostname.split_once('.')?;

    if first.is_empty() || first.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }

    if !first.contains('-') || KNOWN_SERVICES.contains(&first) {
        Some(first.to_string())
    } else {
        None
    }
}
366
/// Last-resort guess of the service from the URI path.
///
/// Covers REST-protocol services reached without an auth header (e.g. from
/// the admin console). Rows of the prefix table are checked top to bottom,
/// and the first row with a matching prefix wins. The exact path `/tags` also
/// maps to EKS. S3 is deliberately absent: almost any path could be an S3
/// object key, so it would swallow everything.
fn resolve_service_from_path(path: &str) -> Option<String> {
    /// `(signing name, path prefixes)` rows, in priority order.
    const PREFIX_TABLE: &[(&str, &[&str])] = &[
        ("lambda", &["/2015-03-31/functions", "/2018-10-31/layers"]),
        // API Gateway v2
        ("execute-api", &["/v2/apis"]),
        // SES v2
        ("ses", &["/v2/email"]),
        (
            "route53",
            &[
                "/2013-04-01/hostedzone",
                "/2013-04-01/healthcheck",
                "/2013-04-01/tags",
            ],
        ),
        (
            "cloudfront",
            &[
                "/2020-05-31/distribution",
                "/2020-05-31/origin-access-control",
                "/2020-05-31/cache-policy",
                "/2020-05-31/tagging",
            ],
        ),
        ("appsync", &["/v1/apis"]),
        (
            "bedrock",
            &["/foundation-models", "/guardrails", "/model-customization"],
        ),
        ("bedrock-runtime", &["/model/"]),
        // EventBridge Scheduler
        ("scheduler", &["/schedules", "/schedule-groups"]),
        ("eks", &["/clusters", "/tags/"]),
    ];

    let by_prefix = PREFIX_TABLE
        .iter()
        .find(|(_, prefixes)| prefixes.iter().any(|prefix| path.starts_with(*prefix)))
        .map(|(service, _)| *service);

    let service = match by_prefix {
        Some(service) => service,
        None if path == "/tags" => "eks",
        None => return None,
    };
    Some(service.to_string())
}
393        // Don't add S3 here as it would catch everything
394        _ => return None,
395    };
396    Some(service.to_string())
397}