use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse};
use http_tunnel_common::constants::MAX_BODY_SIZE_BYTES;
use http_tunnel_common::protocol::Message;
use http_tunnel_common::utils::generate_request_id;
use lambda_runtime::{Error, LambdaEvent};
use tracing::{debug, error, info, warn};
use crate::{
SharedClients, build_api_gateway_response, build_http_request, content_rewrite,
extract_tunnel_id_from_path, lookup_connection_by_tunnel_id, save_pending_request,
send_to_connection, strip_tunnel_id_from_path, wait_for_response,
};
pub async fn handle_forwarding(
event: LambdaEvent<ApiGatewayProxyRequest>,
clients: &SharedClients,
) -> Result<ApiGatewayProxyResponse, Error> {
let mut request = event.payload;
let request_id_context = request.request_context.request_id.clone();
let original_path = request.path.as_deref().unwrap_or("/");
debug!("Processing HTTP request, path: {}", original_path);
let tunnel_id = extract_tunnel_id_from_path(original_path).map_err(|e| {
error!(
"Failed to extract tunnel ID from path {}: {}",
original_path, e
);
"Invalid request path".to_string()
})?;
let actual_path = strip_tunnel_id_from_path(original_path);
debug!(
"Forwarding request for tunnel_id: {} (method: {}, original_path: {}, actual_path: {})",
tunnel_id, request.http_method, original_path, actual_path
);
request.path = Some(actual_path);
if let Some(body) = &request.body {
let body_size = if request.is_base64_encoded {
(body.len() * 3) / 4
} else {
body.len()
};
if body_size > MAX_BODY_SIZE_BYTES {
use aws_lambda_events::encodings::Body;
use http::header::{HeaderName, HeaderValue};
warn!(
"Request body too large: {} bytes (max: {} bytes) for tunnel {}",
body_size, MAX_BODY_SIZE_BYTES, tunnel_id
);
return Ok(ApiGatewayProxyResponse {
status_code: 413,
headers: [
(
HeaderName::from_static("content-type"),
HeaderValue::from_static("text/plain"),
),
(
HeaderName::from_static("x-tunnel-error"),
HeaderValue::from_static("Request Entity Too Large"),
),
]
.into_iter()
.collect(),
multi_value_headers: Default::default(),
body: Some(Body::Text(format!(
"Request body too large: {} bytes (maximum: {} bytes)",
body_size, MAX_BODY_SIZE_BYTES
))),
is_base64_encoded: false,
});
}
}
let connection_id = lookup_connection_by_tunnel_id(&clients.dynamodb, &tunnel_id)
.await
.map_err(|e| {
error!(
"Failed to lookup connection for tunnel_id {}: {}",
tunnel_id, e
);
"Tunnel not found or unavailable".to_string()
})?;
debug!("Found connection: {}", connection_id);
let request_id = generate_request_id();
let http_request = build_http_request(&request, request_id.clone());
let api_gateway_req_id = request_id_context.as_deref().unwrap_or("unknown");
save_pending_request(
&clients.dynamodb,
&request_id,
&connection_id,
api_gateway_req_id,
)
.await
.map_err(|e| {
error!("Failed to save pending request {}: {}", request_id, e);
"Service temporarily unavailable".to_string()
})?;
let message = Message::HttpRequest(http_request);
let message_json = serde_json::to_string(&message).map_err(|e| {
error!("Failed to serialize message: {}", e);
"Service temporarily unavailable".to_string()
})?;
let apigw_management = clients
.apigw_management
.as_ref()
.ok_or("API Gateway Management client not initialized")?;
send_to_connection(apigw_management, &connection_id, &message_json)
.await
.map_err(|e| {
error!(
"Failed to send request {} to connection {}: {}",
request_id, connection_id, e
);
"Tunnel connection unavailable".to_string()
})?;
info!(
"Forwarded request {} to connection {} for tunnel_id {}",
request_id, connection_id, tunnel_id
);
match wait_for_response(&clients.dynamodb, &request_id).await {
Ok(mut response) => {
info!(
"Received response for request {}: status {}",
request_id, response.status_code
);
let content_type = response
.headers
.get("content-type")
.and_then(|v| v.first())
.map(|s| s.as_str())
.unwrap_or("");
let should_rewrite = content_rewrite::should_rewrite_content(content_type);
let (rewritten_body, was_rewritten) = if should_rewrite {
let body_bytes = http_tunnel_common::decode_body(&response.body)
.map_err(|e| format!("Failed to decode response body: {}", e))?;
let body_str = String::from_utf8_lossy(&body_bytes);
content_rewrite::rewrite_response_content(
&body_str,
content_type,
&tunnel_id,
content_rewrite::RewriteStrategy::FullRewrite,
)
.unwrap_or_else(|e| {
warn!("Content rewrite failed: {}, returning original", e);
(body_str.to_string(), false)
})
} else {
debug!("Skipping rewrite for binary content type: {}", content_type);
(String::new(), false)
};
if was_rewritten {
debug!(
"Content rewritten for request {}: {} bytes",
request_id,
rewritten_body.len()
);
response.body = http_tunnel_common::encode_body(rewritten_body.as_bytes());
response.headers.insert(
"content-length".to_string(),
vec![rewritten_body.len().to_string()],
);
response.headers.remove("transfer-encoding");
response.headers.insert(
"x-tunnel-rewrite-applied".to_string(),
vec!["true".to_string()],
);
}
Ok(build_api_gateway_response(response))
}
Err(e) => {
use aws_lambda_events::encodings::Body;
use http::header::{HeaderName, HeaderValue};
error!("Request {} timeout or error: {}", request_id, e);
Ok(ApiGatewayProxyResponse {
status_code: 504,
headers: [
(
HeaderName::from_static("content-type"),
HeaderValue::from_static("text/plain"),
),
(
HeaderName::from_static("x-tunnel-error"),
HeaderValue::from_static("Gateway Timeout"),
),
]
.into_iter()
.collect(),
multi_value_headers: Default::default(),
body: Some(Body::Text(
"Gateway Timeout: No response from agent".to_string(),
)),
is_base64_encoded: false,
})
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use aws_lambda_events::encodings::Body;
    use http::header::{HeaderName, HeaderValue};

    /// Checks the exact shape of the `504` "no response from agent" response: status code,
    /// both headers, body text, and that the body is plain text rather than base64.
    ///
    /// NOTE(review): the response is built by hand here rather than produced by
    /// `handle_forwarding`, so this only pins the expected format. Keep it in sync with the
    /// handler's timeout branch.
    #[test]
    fn test_timeout_response_format() {
        let response = ApiGatewayProxyResponse {
            status_code: 504,
            headers: [
                (
                    HeaderName::from_static("content-type"),
                    HeaderValue::from_static("text/plain"),
                ),
                (
                    HeaderName::from_static("x-tunnel-error"),
                    HeaderValue::from_static("Gateway Timeout"),
                ),
            ]
            .into_iter()
            .collect(),
            multi_value_headers: Default::default(),
            body: Some(Body::Text(
                "Gateway Timeout: No response from agent".to_string(),
            )),
            is_base64_encoded: false,
        };

        assert_eq!(response.status_code, 504);
        assert_eq!(
            response
                .headers
                .get("content-type")
                .and_then(|v| v.to_str().ok()),
            Some("text/plain")
        );
        assert_eq!(
            response
                .headers
                .get("x-tunnel-error")
                .and_then(|v| v.to_str().ok()),
            Some("Gateway Timeout")
        );
        assert!(matches!(
            &response.body,
            Some(Body::Text(text)) if text == "Gateway Timeout: No response from agent"
        ));
        assert!(!response.is_base64_encoded);
    }
}