http_tunnel_handler/handlers/
forwarding.rs

//! ForwardingHandler - Handles HTTP API requests
//!
//! This module receives public HTTP requests via API Gateway HTTP API,
//! extracts the tunnel ID from the request path (path-based routing), looks up
//! the agent's WebSocket connection for that tunnel, forwards the request to the
//! agent, and polls for the response. If no response is received within the
//! timeout, it returns a 504 Gateway Timeout.
7
8use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse};
9use http_tunnel_common::constants::MAX_BODY_SIZE_BYTES;
10use http_tunnel_common::protocol::Message;
11use http_tunnel_common::utils::generate_request_id;
12use lambda_runtime::{Error, LambdaEvent};
13use tracing::{debug, error, info, warn};
14
15use crate::{
16    SharedClients, build_api_gateway_response, build_http_request, content_rewrite,
17    extract_tunnel_id_from_path, lookup_connection_by_tunnel_id, save_pending_request,
18    send_to_connection, strip_tunnel_id_from_path, wait_for_response,
19};
20
21/// Handler for HTTP API requests
22pub async fn handle_forwarding(
23    event: LambdaEvent<ApiGatewayProxyRequest>,
24    clients: &SharedClients,
25) -> Result<ApiGatewayProxyResponse, Error> {
26    let mut request = event.payload;
27    let request_id_context = request.request_context.request_id.clone();
28
29    // Extract tunnel ID from path (path-based routing)
30    // HTTP API v2.0 puts path in request.path (stage is stripped by API Gateway for payload format 2.0)
31    let original_path = request.path.as_deref().unwrap_or("/");
32
33    debug!("Processing HTTP request, path: {}", original_path);
34
35    let tunnel_id = extract_tunnel_id_from_path(original_path).map_err(|e| {
36        error!(
37            "Failed to extract tunnel ID from path {}: {}",
38            original_path, e
39        );
40        // Sanitized error - don't leak internal details
41        "Invalid request path".to_string()
42    })?;
43
44    // Strip tunnel ID from path before forwarding to local service
45    let actual_path = strip_tunnel_id_from_path(original_path);
46
47    debug!(
48        "Forwarding request for tunnel_id: {} (method: {}, original_path: {}, actual_path: {})",
49        tunnel_id, request.http_method, original_path, actual_path
50    );
51
52    // Update request path to stripped version
53    request.path = Some(actual_path);
54
55    // Enforce request size limits
56    if let Some(body) = &request.body {
57        let body_size = if request.is_base64_encoded {
58            // Estimate decoded size (base64 is ~33% larger than binary)
59            (body.len() * 3) / 4
60        } else {
61            body.len()
62        };
63
64        if body_size > MAX_BODY_SIZE_BYTES {
65            use aws_lambda_events::encodings::Body;
66            use http::header::{HeaderName, HeaderValue};
67
68            warn!(
69                "Request body too large: {} bytes (max: {} bytes) for tunnel {}",
70                body_size, MAX_BODY_SIZE_BYTES, tunnel_id
71            );
72
73            return Ok(ApiGatewayProxyResponse {
74                status_code: 413,
75                headers: [
76                    (
77                        HeaderName::from_static("content-type"),
78                        HeaderValue::from_static("text/plain"),
79                    ),
80                    (
81                        HeaderName::from_static("x-tunnel-error"),
82                        HeaderValue::from_static("Request Entity Too Large"),
83                    ),
84                ]
85                .into_iter()
86                .collect(),
87                multi_value_headers: Default::default(),
88                body: Some(Body::Text(format!(
89                    "Request body too large: {} bytes (maximum: {} bytes)",
90                    body_size, MAX_BODY_SIZE_BYTES
91                ))),
92                is_base64_encoded: false,
93            });
94        }
95    }
96
97    // Look up connection ID by tunnel ID
98    let connection_id = lookup_connection_by_tunnel_id(&clients.dynamodb, &tunnel_id)
99        .await
100        .map_err(|e| {
101            error!(
102                "Failed to lookup connection for tunnel_id {}: {}",
103                tunnel_id, e
104            );
105            // Sanitized error - don't leak internal details
106            "Tunnel not found or unavailable".to_string()
107        })?;
108
109    debug!("Found connection: {}", connection_id);
110
111    // Generate request ID
112    let request_id = generate_request_id();
113
114    // Build HttpRequest payload
115    let http_request = build_http_request(&request, request_id.clone());
116
117    // Store pending request in DynamoDB for response correlation
118    let api_gateway_req_id = request_id_context.as_deref().unwrap_or("unknown");
119    save_pending_request(
120        &clients.dynamodb,
121        &request_id,
122        &connection_id,
123        api_gateway_req_id,
124    )
125    .await
126    .map_err(|e| {
127        error!("Failed to save pending request {}: {}", request_id, e);
128        // Sanitized error - don't leak internal details
129        "Service temporarily unavailable".to_string()
130    })?;
131
132    // Forward request to agent via WebSocket
133    let message = Message::HttpRequest(http_request);
134    let message_json = serde_json::to_string(&message).map_err(|e| {
135        error!("Failed to serialize message: {}", e);
136        // Sanitized error - don't leak internal details
137        "Service temporarily unavailable".to_string()
138    })?;
139
140    let apigw_management = clients
141        .apigw_management
142        .as_ref()
143        .ok_or("API Gateway Management client not initialized")?;
144
145    send_to_connection(apigw_management, &connection_id, &message_json)
146        .await
147        .map_err(|e| {
148            error!(
149                "Failed to send request {} to connection {}: {}",
150                request_id, connection_id, e
151            );
152            // Sanitized error - don't leak internal details
153            "Tunnel connection unavailable".to_string()
154        })?;
155
156    info!(
157        "Forwarded request {} to connection {} for tunnel_id {}",
158        request_id, connection_id, tunnel_id
159    );
160
161    // Poll for response with timeout
162    match wait_for_response(&clients.dynamodb, &request_id).await {
163        Ok(mut response) => {
164            info!(
165                "Received response for request {}: status {}",
166                request_id, response.status_code
167            );
168
169            // Apply content rewriting if applicable
170            let content_type = response
171                .headers
172                .get("content-type")
173                .and_then(|v| v.first())
174                .map(|s| s.as_str())
175                .unwrap_or("");
176
177            // Only decode and rewrite if content type needs rewriting (performance optimization)
178            let should_rewrite = content_rewrite::should_rewrite_content(content_type);
179
180            let (rewritten_body, was_rewritten) = if should_rewrite {
181                // Decode body for rewriting
182                let body_bytes = http_tunnel_common::decode_body(&response.body)
183                    .map_err(|e| format!("Failed to decode response body: {}", e))?;
184                let body_str = String::from_utf8_lossy(&body_bytes);
185
186                // Rewrite content (default strategy: FullRewrite)
187                content_rewrite::rewrite_response_content(
188                    &body_str,
189                    content_type,
190                    &tunnel_id,
191                    content_rewrite::RewriteStrategy::FullRewrite,
192                )
193                .unwrap_or_else(|e| {
194                    warn!("Content rewrite failed: {}, returning original", e);
195                    (body_str.to_string(), false)
196                })
197            } else {
198                // Skip decoding for binary content (images, videos, etc.)
199                debug!("Skipping rewrite for binary content type: {}", content_type);
200                (String::new(), false)
201            };
202
203            if was_rewritten {
204                debug!(
205                    "Content rewritten for request {}: {} bytes",
206                    request_id,
207                    rewritten_body.len()
208                );
209
210                // Re-encode the rewritten body
211                response.body = http_tunnel_common::encode_body(rewritten_body.as_bytes());
212
213                // Update Content-Length header
214                response.headers.insert(
215                    "content-length".to_string(),
216                    vec![rewritten_body.len().to_string()],
217                );
218
219                // Remove Transfer-Encoding header if present (we're not chunking)
220                response.headers.remove("transfer-encoding");
221
222                // Add debug header to indicate rewriting was applied
223                response.headers.insert(
224                    "x-tunnel-rewrite-applied".to_string(),
225                    vec!["true".to_string()],
226                );
227            }
228
229            // Convert HttpResponse to API Gateway response
230            Ok(build_api_gateway_response(response))
231        }
232        Err(e) => {
233            use aws_lambda_events::encodings::Body;
234            use http::header::{HeaderName, HeaderValue};
235
236            error!("Request {} timeout or error: {}", request_id, e);
237            // Return 504 Gateway Timeout
238            Ok(ApiGatewayProxyResponse {
239                status_code: 504,
240                headers: [
241                    (
242                        HeaderName::from_static("content-type"),
243                        HeaderValue::from_static("text/plain"),
244                    ),
245                    (
246                        HeaderName::from_static("x-tunnel-error"),
247                        HeaderValue::from_static("Gateway Timeout"),
248                    ),
249                ]
250                .into_iter()
251                .collect(),
252                multi_value_headers: Default::default(),
253                body: Some(Body::Text(
254                    "Gateway Timeout: No response from agent".to_string(),
255                )),
256                is_base64_encoded: false,
257            })
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use aws_lambda_events::encodings::Body;
    use http::header::{HeaderName, HeaderValue};

    /// Sanity-checks the basic shape (status code, non-empty headers, present body)
    /// of a plain-text 504 response.
    ///
    /// The response is constructed locally; `handle_forwarding` itself is not called.
    #[test]
    fn test_timeout_response_format() {
        let content_type_header = (
            HeaderName::from_static("content-type"),
            HeaderValue::from_static("text/plain"),
        );
        let message = "Gateway Timeout: No response from agent".to_string();

        let response = ApiGatewayProxyResponse {
            status_code: 504,
            headers: std::iter::once(content_type_header).collect(),
            multi_value_headers: Default::default(),
            body: Some(Body::Text(message)),
            is_base64_encoded: false,
        };

        assert_eq!(response.status_code, 504);
        assert!(!response.headers.is_empty());
        assert!(response.body.is_some());
    }
}