1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use axum::body::Body;
6use axum::extract::State;
7use axum::http::{HeaderMap, Method, Response, StatusCode, Uri};
8use bytes::Bytes;
9use tracing::{debug, info, warn};
10
11use crate::auth;
12use crate::authz::AuthzEngine;
13use crate::error::AwsError;
14use crate::events::EventBus;
15use crate::protocol::{self, Protocol, RouteDefinition};
16use crate::ServiceHandler;
17
18#[derive(Clone)]
20pub struct AppState {
21 pub services: Arc<HashMap<String, Arc<dyn ServiceHandler>>>,
23 pub routes: Arc<HashMap<String, Vec<RouteDefinition>>>,
25 pub default_region: String,
27 pub default_account_id: String,
29 pub event_bus: EventBus,
31 pub request_count: Arc<AtomicU64>,
33 pub start_time: std::time::Instant,
35 pub authz: Arc<AuthzEngine>,
37}
38
39impl AppState {
40 pub fn new(default_region: String, default_account_id: String) -> Self {
41 Self {
42 services: Arc::new(HashMap::new()),
43 routes: Arc::new(HashMap::new()),
44 default_region,
45 default_account_id,
46 event_bus: EventBus::new(),
47 request_count: Arc::new(AtomicU64::new(0)),
48 start_time: std::time::Instant::now(),
49 authz: Arc::new(AuthzEngine::from_env()),
50 }
51 }
52
53 pub fn register(&mut self, handler: Arc<dyn ServiceHandler>, routes: Vec<RouteDefinition>) {
55 let signing_name = handler.signing_name().to_string();
56 let service_name = handler.service_name().to_string();
57
58 info!(
59 service = %service_name,
60 signing_name = %signing_name,
61 protocol = ?handler.protocol(),
62 routes = routes.len(),
63 "Registered service"
64 );
65
66 Arc::get_mut(&mut self.services)
67 .unwrap()
68 .insert(signing_name.clone(), handler);
69
70 if !routes.is_empty() {
71 Arc::get_mut(&mut self.routes)
72 .unwrap()
73 .insert(signing_name, routes);
74 }
75 }
76}
77
78pub async fn handle_request(
80 State(state): State<AppState>,
81 method: Method,
82 uri: Uri,
83 headers: HeaderMap,
84 body: Bytes,
85) -> Response<Body> {
86 state.request_count.fetch_add(1, Ordering::Relaxed);
87
88 let request_id = uuid::Uuid::new_v4().to_string();
89
90 debug!(
91 method = %method,
92 uri = %uri,
93 request_id = %request_id,
94 "Incoming request"
95 );
96
97 match process_request(&state, &method, &uri, &headers, &body, &request_id).await {
98 Ok((status, mut resp_headers, resp_body)) => {
99 let mut builder = Response::builder().status(status);
100 for (key, value) in resp_headers.drain() {
101 if let Some(key) = key {
102 builder = builder.header(key, value);
103 }
104 }
105 builder.body(Body::from(resp_body)).unwrap()
106 }
107 Err((protocol, error)) => {
108 warn!(
109 error_code = %error.code,
110 error_message = %error.message,
111 request_id = %request_id,
112 "Request failed"
113 );
114 let (status, mut resp_headers, resp_body) =
115 protocol::serialize_error(protocol, &error, &request_id);
116 let mut builder = Response::builder().status(status);
117 for (key, value) in resp_headers.drain() {
118 if let Some(key) = key {
119 builder = builder.header(key, value);
120 }
121 }
122 builder.body(Body::from(resp_body)).unwrap()
123 }
124 }
125}
126
127async fn process_request(
128 state: &AppState,
129 method: &Method,
130 uri: &Uri,
131 headers: &HeaderMap,
132 body: &Bytes,
133 request_id: &str,
134) -> Result<(StatusCode, HeaderMap, Bytes), (Protocol, AwsError)> {
135 let (service_name, region, account_id, access_key) =
137 extract_service_info(state, headers, uri);
138
139 let handler = state
141 .services
142 .get(&service_name)
143 .ok_or_else(|| {
144 let protocol = protocol::detect_protocol(headers, body).unwrap_or(Protocol::RestJson1);
145 (
146 protocol,
147 AwsError::bad_request(
148 "UnknownService",
149 format!("Service '{service_name}' is not registered"),
150 ),
151 )
152 })?;
153
154 let protocol = handler.protocol();
155
156 let detected = protocol::detect_protocol(headers, body).unwrap_or(protocol);
158
159 let empty_routes = Vec::new();
161 let routes = state.routes.get(&service_name).unwrap_or(&empty_routes);
162
163 let parsed = protocol::parse_request(detected, method, uri, headers, body, routes)
165 .map_err(|e| (detected, e))?;
166
167 debug!(
168 service = %service_name,
169 operation = %parsed.operation,
170 request_id = %request_id,
171 "Dispatching operation"
172 );
173
174 let ctx = crate::router::RequestContext {
176 account_id,
177 region,
178 service: service_name.clone(),
179 access_key,
180 request_id: request_id.to_string(),
181 method: method.to_string(),
182 uri: uri.to_string(),
183 event_bus: Some(state.event_bus.clone()),
184 };
185
186 if let (Some(action), Some(resource)) = (
188 handler.iam_action(&parsed.operation),
189 handler.iam_resource(&parsed.operation, &parsed.input, &ctx),
190 ) {
191 state
192 .authz
193 .check(&ctx, &action, &resource)
194 .map_err(|e| (detected, e))?;
195 } else {
196 debug!(
197 service = %service_name,
198 operation = %parsed.operation,
199 "Skipping IAM check — handler does not declare action/resource"
200 );
201 }
202
203 let result = handler
205 .handle(&parsed.operation, parsed.input, &ctx)
206 .await
207 .map_err(|e| (detected, e))?;
208
209 Ok(protocol::serialize_response(
214 detected,
215 &parsed.operation,
216 &result,
217 request_id,
218 ))
219}
220
221fn extract_service_info(
223 state: &AppState,
224 headers: &HeaderMap,
225 uri: &Uri,
226) -> (String, String, String, Option<String>) {
227 if let Some(auth_header) = headers.get("authorization").and_then(|v| v.to_str().ok()) {
229 if let Some(creds) = auth::parse_authorization(auth_header) {
230 return (
231 creds.service,
232 creds.region,
233 state.default_account_id.clone(),
234 Some(creds.access_key),
235 );
236 }
237 }
238
239 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
241 if let Some(service) = resolve_service_from_target(target) {
242 return (
243 service,
244 state.default_region.clone(),
245 state.default_account_id.clone(),
246 None,
247 );
248 }
249 }
250
251 if let Some(host) = headers.get("host").and_then(|v| v.to_str().ok()) {
253 if let Some(service) = extract_service_from_host(host) {
254 return (
255 service,
256 state.default_region.clone(),
257 state.default_account_id.clone(),
258 None,
259 );
260 }
261 }
262
263 if let Some(query) = uri.query() {
265 if query.contains("X-Amz-Credential") {
266 if let Some(cred_start) = query.find("X-Amz-Credential=") {
267 let cred_val = &query[cred_start + 17..];
268 let cred_end = cred_val.find('&').unwrap_or(cred_val.len());
269 let cred = &cred_val[..cred_end];
270 let cred_decoded = cred.replace("%2F", "/");
271 let parts: Vec<&str> = cred_decoded.split('/').collect();
272 if parts.len() >= 4 {
273 return (
274 parts[3].to_string(),
275 parts[2].to_string(),
276 state.default_account_id.clone(),
277 Some(parts[0].to_string()),
278 );
279 }
280 }
281 }
282 }
283
284 let path = uri.path();
286 if let Some(service) = resolve_service_from_path(path) {
287 return (
288 service,
289 state.default_region.clone(),
290 state.default_account_id.clone(),
291 None,
292 );
293 }
294
295 warn!(
297 auth = ?headers.get("authorization").map(|v| v.to_str().unwrap_or("<non-utf8>")),
298 target = ?headers.get("x-amz-target").map(|v| v.to_str().unwrap_or("<non-utf8>")),
299 host = ?headers.get("host").map(|v| v.to_str().unwrap_or("<non-utf8>")),
300 path = %path,
301 "Could not determine service — falling back to 'unknown'"
302 );
303 (
304 "unknown".to_string(),
305 state.default_region.clone(),
306 state.default_account_id.clone(),
307 None,
308 )
309}
310
/// Maps an `X-Amz-Target` header value to a service signing name.
///
/// JSON-protocol requests prefix the operation with a service identifier, e.g.
/// `DynamoDB_20120810.PutItem`. Returns `None` for unrecognized targets.
fn resolve_service_from_target(target: &str) -> Option<String> {
    // Identifiers that themselves contain dots, matched against the whole target.
    const DOTTED_PREFIXES: &[(&str, &str)] = &[("com.amazonaws.cloudtrail.", "cloudtrail")];

    // Matched against the first dot-separated segment, in order; the first match wins.
    const SEGMENT_PREFIXES: &[(&str, &str)] = &[
        ("DynamoDB", "dynamodb"),
        ("AmazonSQS", "sqs"),
        ("AmazonSNS", "sns"),
        ("TrentService", "kms"),
        ("secretsmanager", "secretsmanager"),
        ("AmazonSSM", "ssm"),
        ("Logs", "logs"),
        ("Kinesis", "kinesis"),
        ("AWSStepFunctions", "states"),
        ("AWSEvents", "events"),
        ("AWSCognitoIdentityProviderService", "cognito-idp"),
        ("AWSCognitoIdentityService", "cognito-identity"),
        ("AmazonEC2ContainerServiceV2", "ecs"),
        ("AmazonEC2ContainerRegistry", "ecr"),
        ("AmazonAthena", "athena"),
        ("AWSGlue", "glue"),
        ("CertificateManager", "acm"),
        ("AWSWAF", "wafv2"),
        ("Comprehend", "comprehend"),
        ("kendra", "kendra"),
        // Kendra's JSON target prefix as sent by the AWS SDKs.
        ("AWSKendraFrontendService", "kendra"),
        ("AWSOrganizationsV", "organizations"),
        ("CloudTrail_", "cloudtrail"),
        ("Firehose_", "firehose"),
    ];

    if let Some(&(_, service)) = DOTTED_PREFIXES
        .iter()
        .find(|&&(prefix, _)| target.starts_with(prefix))
    {
        return Some(service.to_string());
    }

    let segment = target.split('.').next()?;
    SEGMENT_PREFIXES
        .iter()
        .find(|&&(prefix, _)| segment.starts_with(prefix))
        .map(|&(_, service)| service.to_string())
}
349
/// Derives a service name from the first label of a `Host` header value.
///
/// Handles hosts such as `sqs.us-east-1.amazonaws.com` or `execute-api.localhost:4566`; any
/// `:port` suffix is ignored. Returns `None` in these cases:
/// - a single-label host (e.g. `localhost:4566`);
/// - an empty or purely numeric first label, so an IPv4 literal like `127.0.0.1` is not
///   mistaken for a service;
/// - a hyphenated first label that is not a known hyphenated service
///   (e.g. `my-bucket.s3.amazonaws.com`).
fn extract_service_from_host(host: &str) -> Option<String> {
    // Service names containing '-' that must still be accepted as the first label.
    const HYPHENATED_SERVICES: &[&str] = &["execute-api", "cognito-idp", "cognito-identity"];

    let host = host.split(':').next().unwrap_or(host);
    let mut labels = host.split('.');
    let first = labels.next()?;
    // Require at least two labels.
    labels.next()?;

    if first.is_empty() || first.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    if first.contains('-') && !HYPHENATED_SERVICES.contains(&first) {
        return None;
    }
    Some(first.to_string())
}
366
/// Infers a service from well-known REST API path prefixes.
///
/// Prefixes are tried in table order and the first match wins. The bare path `/tags` is
/// claimed by EKS. Returns `None` when no prefix applies.
fn resolve_service_from_path(path: &str) -> Option<String> {
    const PATH_PREFIXES: &[(&str, &str)] = &[
        ("/2015-03-31/functions", "lambda"),
        ("/2018-10-31/layers", "lambda"),
        ("/v2/apis", "execute-api"),
        ("/v2/email", "ses"),
        ("/2013-04-01/hostedzone", "route53"),
        ("/2013-04-01/healthcheck", "route53"),
        ("/2013-04-01/tags", "route53"),
        ("/2020-05-31/distribution", "cloudfront"),
        ("/2020-05-31/origin-access-control", "cloudfront"),
        ("/2020-05-31/cache-policy", "cloudfront"),
        ("/2020-05-31/tagging", "cloudfront"),
        ("/v1/apis", "appsync"),
        ("/foundation-models", "bedrock"),
        ("/guardrails", "bedrock"),
        ("/model-customization", "bedrock"),
        ("/model/", "bedrock-runtime"),
        ("/schedules", "scheduler"),
        ("/schedule-groups", "scheduler"),
        ("/clusters", "eks"),
        ("/tags/", "eks"),
    ];

    PATH_PREFIXES
        .iter()
        .find(|&&(prefix, _)| path.starts_with(prefix))
        .map(|&(_, service)| service)
        // No table prefix matches the exact path "/tags", so checking it last is equivalent.
        .or_else(|| (path == "/tags").then_some("eks"))
        .map(str::to_owned)
}