use std::sync::Arc;
use axum::body::Body;
use axum::extract::{Path, State};
use axum::http::{HeaderMap, Method, Response, StatusCode, Uri};
use axum::response::IntoResponse;
use bytes::Bytes;
use serde_json::json;
use tracing::{debug, warn};
use awsim_apigateway::ApiGatewayService;
use awsim_core::{RequestContext, ServiceHandler};
/// Shared state handed to the API Gateway proxy handler (`handle_proxy`).
///
/// The struct is `Clone`; the service handles are held behind `Arc`.
#[derive(Clone)]
pub struct ProxyState {
/// API Gateway service whose store is consulted to resolve the incoming request to a route.
pub apigw: Arc<ApiGatewayService>,
/// Lambda service handler used to invoke the integration target.
/// When `None`, `handle_proxy` answers matched routes with `503 Service Unavailable`.
pub lambda: Option<Arc<dyn ServiceHandler>>,
/// Account id used both for the API Gateway store lookup and for the Lambda request context.
pub default_account_id: String,
/// Region used both for the API Gateway store lookup and for the Lambda request context.
pub default_region: String,
}
pub async fn handle_proxy(
State(state): State<ProxyState>,
Path((api_id, stage, path)): Path<(String, String, String)>,
method: Method,
headers: HeaderMap,
uri: Uri,
body: Bytes,
) -> impl IntoResponse {
debug!(
api_id = %api_id,
stage = %stage,
path = %path,
method = %method,
"API Gateway proxy request"
);
let agw_state = state
.apigw
.store()
.get(&state.default_account_id, &state.default_region);
let path_with_slash = format!("/{path}");
let query_string = uri.query().unwrap_or("");
let proxy_result = awsim_apigateway::proxy_request(
&api_id,
&stage,
method.as_str(),
&path_with_slash,
query_string,
&headers,
&body,
&agw_state,
)
.await;
match proxy_result {
None => {
warn!(
api_id = %api_id,
path = %path_with_slash,
method = %method,
"No matching route found in API Gateway"
);
error_response(
StatusCode::NOT_FOUND,
&format!("No route found for {method} {path_with_slash} in API {api_id}"),
)
}
Some(proxy) => {
let lambda_handler = match &state.lambda {
Some(h) => Arc::clone(h),
None => {
warn!("Lambda service not registered — cannot invoke function");
return error_response(
StatusCode::SERVICE_UNAVAILABLE,
"Lambda service not registered",
);
}
};
let function_name = extract_function_name(&proxy.integration_uri);
let ctx = RequestContext {
account_id: state.default_account_id.clone(),
region: state.default_region.clone(),
service: "lambda".to_string(),
access_key: None,
request_id: uuid::Uuid::new_v4().to_string(),
method: method.to_string(),
uri: uri.to_string(),
event_bus: None,
};
let invoke_input = json!({
"FunctionName": function_name,
"InvocationType": "RequestResponse",
"Payload": proxy.event,
});
match lambda_handler.handle("Invoke", invoke_input, &ctx).await {
Ok(result) => lambda_response_to_http(result),
Err(e) => {
warn!(
function = %function_name,
error = %e.message,
"Lambda invocation failed"
);
error_response(
StatusCode::INTERNAL_SERVER_ERROR,
&format!("Lambda invocation error: {}", e.message),
)
}
}
}
}
}
fn lambda_response_to_http(result: serde_json::Value) -> Response<Body> {
if let Some(status_code) = result.get("statusCode").and_then(|v| v.as_u64()) {
let status =
StatusCode::from_u16(status_code as u16).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let mut builder = Response::builder().status(status);
if let Some(resp_headers) = result.get("headers").and_then(|v| v.as_object()) {
for (key, value) in resp_headers {
if let Some(v) = value.as_str()
&& let Ok(header_value) = v.parse::<axum::http::HeaderValue>() {
builder = builder.header(key.as_str(), header_value);
}
}
}
let is_base64 = result
.get("isBase64Encoded")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let body_bytes: Bytes = match result.get("body") {
None => Bytes::new(),
Some(serde_json::Value::Null) => Bytes::new(),
Some(serde_json::Value::String(s)) => {
if is_base64 {
use base64::Engine;
base64::engine::general_purpose::STANDARD
.decode(s)
.unwrap_or_else(|_| s.as_bytes().to_vec())
.into()
} else {
Bytes::from(s.clone())
}
}
Some(other) => Bytes::from(other.to_string()),
};
builder.body(Body::from(body_bytes)).unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
})
} else {
let body = serde_json::to_vec(&result).unwrap_or_default();
Response::builder()
.status(StatusCode::OK)
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap()
}
}
/// Extracts the Lambda function name from an integration URI.
///
/// Accepted inputs:
/// * a bare function name, returned unchanged;
/// * a Lambda function ARN, `arn:aws:lambda:<region>:<account>:function:<name>`,
///   optionally followed by `:<qualifier>` (alias or version);
/// * an API Gateway integration URI that embeds such an ARN, e.g.
///   `arn:aws:apigateway:<region>:lambda:path/2015-03-31/functions/<lambda-arn>/invocations`.
///
/// Only the function name is returned; a trailing qualifier is dropped. A Lambda ARN that
/// does not follow the `function:<name>` layout falls back to its last `:`-separated
/// segment.
fn extract_function_name(uri: &str) -> String {
    const LAMBDA_ARN_PREFIX: &str = "arn:aws:lambda:";
    const INVOCATIONS_SUFFIX: &str = "/invocations";

    // Locate the Lambda ARN anywhere in the URI so the API Gateway wrapper form is handled.
    let Some(start) = uri.find(LAMBDA_ARN_PREFIX) else {
        return uri.to_string();
    };
    let arn = &uri[start..];
    let arn = arn.strip_suffix(INVOCATIONS_SUFFIX).unwrap_or(arn);

    let segments: Vec<&str> = arn.split(':').collect();
    match segments.as_slice() {
        // arn : aws : lambda : region : account : function : name [: qualifier]
        [_, _, _, _, _, "function", name, ..] if !name.is_empty() => (*name).to_string(),
        _ => segments.last().copied().unwrap_or(arn).to_string(),
    }
}
/// Builds a JSON error response with the given status.
///
/// The body is the object `{"message": <message>}` and the `content-type` header is set to
/// `application/json`.
fn error_response(status: StatusCode, message: &str) -> Response<Body> {
    let payload = json!({
        "message": message,
    });
    let encoded = serde_json::to_vec(&payload).unwrap_or_default();

    let builder = Response::builder()
        .status(status)
        .header("content-type", "application/json");

    builder.body(Body::from(encoded)).unwrap()
}