http_tunnel_handler/
lib.rs

1//! Shared utilities for AWS Lambda handlers
2//!
3//! This module provides common functionality used across all Lambda functions including
4//! DynamoDB operations, request/response transformations, and helper functions.
5
6use anyhow::{Context, Result, anyhow};
7use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse};
8use aws_sdk_apigatewaymanagement::Client as ApiGatewayManagementClient;
9use aws_sdk_apigatewaymanagement::primitives::Blob;
10use aws_sdk_dynamodb::Client as DynamoDbClient;
11use aws_sdk_dynamodb::types::AttributeValue;
12use aws_sdk_eventbridge::Client as EventBridgeClient;
13use http_tunnel_common::ConnectionMetadata;
14use http_tunnel_common::constants::{
15    PENDING_REQUEST_TTL_SECS, POLL_BACKOFF_MULTIPLIER, POLL_INITIAL_INTERVAL_MS,
16    POLL_MAX_INTERVAL_MS, REQUEST_TIMEOUT_SECS,
17};
18use http_tunnel_common::protocol::{HttpRequest, HttpResponse};
19use http_tunnel_common::utils::{calculate_ttl, current_timestamp_millis, current_timestamp_secs};
20use std::time::{Duration, Instant};
21use tracing::{debug, error};
22
23pub mod auth;
24pub mod content_rewrite;
25pub mod error_handling;
26pub mod handlers;
27
/// Report whether the event-driven response wait strategy is switched on.
///
/// Reads the `USE_EVENT_DRIVEN` environment variable and returns `true` only when its
/// lowercased value is exactly `"true"`. A missing (or non-Unicode) variable means disabled.
pub fn is_event_driven_enabled() -> bool {
    match std::env::var("USE_EVENT_DRIVEN") {
        Ok(flag) => flag.to_lowercase() == "true",
        Err(_) => false,
    }
}
35
36/// Shared AWS clients used across all handlers
37pub struct SharedClients {
38    pub dynamodb: DynamoDbClient,
39    pub apigw_management: Option<ApiGatewayManagementClient>,
40    pub eventbridge: EventBridgeClient,
41}
42
43/// Extract tunnel ID from request path (path-based routing)
44/// Example: "/abc123/api/users" -> "abc123"
45pub fn extract_tunnel_id_from_path(path: &str) -> Result<String> {
46    let parts: Vec<&str> = path.trim_start_matches('/').split('/').collect();
47    if parts.is_empty() || parts[0].is_empty() {
48        return Err(anyhow!("Missing tunnel ID in path"));
49    }
50    let tunnel_id = parts[0].to_string();
51
52    // Validate tunnel ID format to prevent injection attacks
53    http_tunnel_common::validation::validate_tunnel_id(&tunnel_id)
54        .context("Invalid tunnel ID format")?;
55
56    Ok(tunnel_id)
57}
58
/// Remove the leading tunnel ID segment from `path` before forwarding to the local service.
///
/// Leading slashes are ignored. Examples:
/// `"/abc123/api/users"` -> `"/api/users"`, `"/abc123"` -> `"/"`, `"/abc123/"` -> `"/"`.
pub fn strip_tunnel_id_from_path(path: &str) -> String {
    match path.trim_start_matches('/').split_once('/') {
        Some((_tunnel_id, rest)) if !rest.is_empty() => format!("/{rest}"),
        _ => String::from("/"),
    }
}
70
71/// Save connection metadata to DynamoDB
72pub async fn save_connection_metadata(
73    client: &DynamoDbClient,
74    metadata: &ConnectionMetadata,
75) -> Result<()> {
76    let table_name = std::env::var("CONNECTIONS_TABLE_NAME")
77        .context("CONNECTIONS_TABLE_NAME environment variable not set")?;
78
79    client
80        .put_item()
81        .table_name(&table_name)
82        .item(
83            "connectionId",
84            AttributeValue::S(metadata.connection_id.clone()),
85        )
86        .item("tunnelId", AttributeValue::S(metadata.tunnel_id.clone()))
87        .item("publicUrl", AttributeValue::S(metadata.public_url.clone()))
88        .item(
89            "createdAt",
90            AttributeValue::N(metadata.created_at.to_string()),
91        )
92        .item("ttl", AttributeValue::N(metadata.ttl.to_string()))
93        .send()
94        .await
95        .context("Failed to save connection metadata to DynamoDB")?;
96
97    Ok(())
98}
99
100/// Delete connection from DynamoDB
101pub async fn delete_connection(client: &DynamoDbClient, connection_id: &str) -> Result<()> {
102    let table_name = std::env::var("CONNECTIONS_TABLE_NAME")
103        .context("CONNECTIONS_TABLE_NAME environment variable not set")?;
104
105    client
106        .delete_item()
107        .table_name(&table_name)
108        .key("connectionId", AttributeValue::S(connection_id.to_string()))
109        .send()
110        .await
111        .context("Failed to delete connection from DynamoDB")?;
112
113    Ok(())
114}
115
116/// Look up connection ID by tunnel ID using GSI (path-based routing)
117pub async fn lookup_connection_by_tunnel_id(
118    client: &DynamoDbClient,
119    tunnel_id: &str,
120) -> Result<String> {
121    let table_name = std::env::var("CONNECTIONS_TABLE_NAME")
122        .context("CONNECTIONS_TABLE_NAME environment variable not set")?;
123    let index_name = "tunnel-id-index";
124
125    let result = client
126        .query()
127        .table_name(&table_name)
128        .index_name(index_name)
129        .key_condition_expression("tunnelId = :tunnel_id")
130        .expression_attribute_values(":tunnel_id", AttributeValue::S(tunnel_id.to_string()))
131        .limit(1)
132        .send()
133        .await
134        .context("Failed to query connection by tunnel ID")?;
135
136    let items = result.items.ok_or_else(|| anyhow!("No items returned"))?;
137    let item = items
138        .first()
139        .ok_or_else(|| anyhow!("Connection not found for tunnel ID: {}", tunnel_id))?;
140
141    let connection_id = item
142        .get("connectionId")
143        .and_then(|v| v.as_s().ok())
144        .ok_or_else(|| anyhow!("Missing connectionId in DynamoDB item"))?;
145
146    Ok(connection_id.clone())
147}
148
149/// Build HttpRequest from API Gateway event
150pub fn build_http_request(request: &ApiGatewayProxyRequest, request_id: String) -> HttpRequest {
151    let method = request.http_method.to_string();
152
153    let uri = format!("{}{}", request.path.as_deref().unwrap_or("/"), {
154        let params = &request.query_string_parameters;
155        if params.is_empty() {
156            String::new()
157        } else {
158            format!(
159                "?{}",
160                params
161                    .iter()
162                    .map(|(k, v)| format!("{}={}", k, v))
163                    .collect::<Vec<_>>()
164                    .join("&")
165            )
166        }
167    });
168
169    let headers = request
170        .headers
171        .iter()
172        .map(|(k, v)| {
173            (
174                k.as_str().to_string(),
175                vec![v.to_str().unwrap_or("").to_string()],
176            )
177        })
178        .collect();
179
180    let body = request
181        .body
182        .as_ref()
183        .map(|b| {
184            if request.is_base64_encoded {
185                b.to_string() // Already base64
186            } else {
187                http_tunnel_common::encode_body(b.as_bytes())
188            }
189        })
190        .unwrap_or_default();
191
192    HttpRequest {
193        request_id,
194        method,
195        uri,
196        headers,
197        body,
198        timestamp: current_timestamp_millis(),
199    }
200}
201
202/// Save pending request to DynamoDB
203pub async fn save_pending_request(
204    client: &DynamoDbClient,
205    request_id: &str,
206    connection_id: &str,
207    api_gateway_request_id: &str,
208) -> Result<()> {
209    let table_name = std::env::var("PENDING_REQUESTS_TABLE_NAME")
210        .context("PENDING_REQUESTS_TABLE_NAME environment variable not set")?;
211    let created_at = current_timestamp_secs();
212    let ttl = calculate_ttl(PENDING_REQUEST_TTL_SECS);
213
214    client
215        .put_item()
216        .table_name(&table_name)
217        .item("requestId", AttributeValue::S(request_id.to_string()))
218        .item("connectionId", AttributeValue::S(connection_id.to_string()))
219        .item(
220            "apiGatewayRequestId",
221            AttributeValue::S(api_gateway_request_id.to_string()),
222        )
223        .item("createdAt", AttributeValue::N(created_at.to_string()))
224        .item("ttl", AttributeValue::N(ttl.to_string()))
225        .item("status", AttributeValue::S("pending".to_string()))
226        .send()
227        .await
228        .context("Failed to save pending request to DynamoDB")?;
229
230    Ok(())
231}
232
233/// Send message to WebSocket connection
234pub async fn send_to_connection(
235    client: &ApiGatewayManagementClient,
236    connection_id: &str,
237    data: &str,
238) -> Result<()> {
239    client
240        .post_to_connection()
241        .connection_id(connection_id)
242        .data(Blob::new(data.as_bytes()))
243        .send()
244        .await
245        .context("Failed to send message to WebSocket connection")?;
246
247    Ok(())
248}
249
250/// Wait for response with event-driven or polling approach based on USE_EVENT_DRIVEN flag
251pub async fn wait_for_response(client: &DynamoDbClient, request_id: &str) -> Result<HttpResponse> {
252    if is_event_driven_enabled() {
253        wait_for_response_event_driven(client, request_id).await
254    } else {
255        wait_for_response_polling(client, request_id).await
256    }
257}
258
259/// Helper function to check for completed response in DynamoDB
260async fn check_for_response(
261    client: &DynamoDbClient,
262    table_name: &str,
263    request_id: &str,
264) -> Result<Option<HttpResponse>> {
265    let result = client
266        .get_item()
267        .table_name(table_name)
268        .key("requestId", AttributeValue::S(request_id.to_string()))
269        .send()
270        .await
271        .context("Failed to get pending request from DynamoDB")?;
272
273    if let Some(item) = result.item {
274        let status = item
275            .get("status")
276            .and_then(|v| v.as_s().ok())
277            .ok_or_else(|| anyhow!("Missing status in DynamoDB item"))?;
278
279        if status == "completed" {
280            // Extract response data
281            let response_data = item
282                .get("responseData")
283                .and_then(|v| v.as_s().ok())
284                .ok_or_else(|| anyhow!("Missing responseData in completed request"))?;
285
286            let response: HttpResponse = serde_json::from_str(response_data)
287                .context("Failed to parse response data JSON")?;
288
289            // Clean up pending request
290            if let Err(e) = client
291                .delete_item()
292                .table_name(table_name)
293                .key("requestId", AttributeValue::S(request_id.to_string()))
294                .send()
295                .await
296            {
297                error!("Failed to clean up pending request: {}", e);
298            }
299
300            return Ok(Some(response));
301        }
302    }
303
304    Ok(None)
305}
306
307/// Event-driven approach: Check DynamoDB immediately, sleep once, then check again
308/// This dramatically reduces wasted polling when combined with DynamoDB Streams
309async fn wait_for_response_event_driven(
310    client: &DynamoDbClient,
311    request_id: &str,
312) -> Result<HttpResponse> {
313    let table_name = std::env::var("PENDING_REQUESTS_TABLE_NAME")
314        .context("PENDING_REQUESTS_TABLE_NAME environment variable not set")?;
315    let timeout = Duration::from_secs(REQUEST_TIMEOUT_SECS);
316    let start = Instant::now();
317
318    // With EventBridge notifications from DynamoDB Streams,
319    // responses should be ready almost immediately
320    // We use a simplified check pattern: immediate check, wait, final check
321
322    // First check (might already be ready)
323    if let Some(response) = check_for_response(client, &table_name, request_id).await? {
324        return Ok(response);
325    }
326
327    // DynamoDB Stream + EventBridge takes ~100-500ms to process
328    // Sleep for most of the remaining time
329    let wait_duration = Duration::from_millis(800);
330    tokio::time::sleep(wait_duration).await;
331
332    // Second check
333    if let Some(response) = check_for_response(client, &table_name, request_id).await? {
334        return Ok(response);
335    }
336
337    // Final polling loop for any edge cases (much shorter than before)
338    let mut poll_interval = Duration::from_millis(200);
339    loop {
340        if start.elapsed() > timeout {
341            return Err(anyhow!("Request timeout waiting for response"));
342        }
343
344        tokio::time::sleep(poll_interval).await;
345
346        if let Some(response) = check_for_response(client, &table_name, request_id).await? {
347            return Ok(response);
348        }
349
350        poll_interval = Duration::from_millis(500); // Fixed 500ms for final polls
351    }
352}
353
354/// Original polling approach with exponential backoff
355async fn wait_for_response_polling(
356    client: &DynamoDbClient,
357    request_id: &str,
358) -> Result<HttpResponse> {
359    let table_name = std::env::var("PENDING_REQUESTS_TABLE_NAME")
360        .context("PENDING_REQUESTS_TABLE_NAME environment variable not set")?;
361    let timeout = Duration::from_secs(REQUEST_TIMEOUT_SECS);
362    let start = Instant::now();
363
364    // Start with initial poll interval, increase to max with backoff
365    let mut poll_interval = Duration::from_millis(POLL_INITIAL_INTERVAL_MS);
366    let max_poll_interval = Duration::from_millis(POLL_MAX_INTERVAL_MS);
367
368    loop {
369        if start.elapsed() > timeout {
370            return Err(anyhow!("Request timeout waiting for response"));
371        }
372
373        // Query DynamoDB for response
374        let result = client
375            .get_item()
376            .table_name(&table_name)
377            .key("requestId", AttributeValue::S(request_id.to_string()))
378            .send()
379            .await
380            .context("Failed to get pending request from DynamoDB")?;
381
382        if let Some(item) = result.item {
383            let status = item
384                .get("status")
385                .and_then(|v| v.as_s().ok())
386                .ok_or_else(|| anyhow!("Missing status in DynamoDB item"))?;
387
388            if status == "completed" {
389                // Extract response data
390                let response_data = item
391                    .get("responseData")
392                    .and_then(|v| v.as_s().ok())
393                    .ok_or_else(|| anyhow!("Missing responseData in completed request"))?;
394
395                let response: HttpResponse = serde_json::from_str(response_data)
396                    .context("Failed to parse response data JSON")?;
397
398                // Clean up pending request
399                if let Err(e) = client
400                    .delete_item()
401                    .table_name(&table_name)
402                    .key("requestId", AttributeValue::S(request_id.to_string()))
403                    .send()
404                    .await
405                {
406                    error!("Failed to clean up pending request: {}", e);
407                }
408
409                return Ok(response);
410            }
411        }
412
413        tokio::time::sleep(poll_interval).await;
414
415        // Exponential backoff with max limit
416        poll_interval = std::cmp::min(poll_interval * POLL_BACKOFF_MULTIPLIER, max_poll_interval);
417    }
418}
419
420/// Convert HttpResponse to API Gateway response
421pub fn build_api_gateway_response(response: HttpResponse) -> ApiGatewayProxyResponse {
422    use http::header::{HeaderName, HeaderValue};
423
424    let headers = response
425        .headers
426        .iter()
427        .filter_map(|(k, v)| {
428            v.first().and_then(|val| {
429                HeaderName::from_bytes(k.as_bytes())
430                    .ok()
431                    .and_then(|name| HeaderValue::from_str(val).ok().map(|value| (name, value)))
432            })
433        })
434        .collect();
435
436    use aws_lambda_events::encodings::Body;
437
438    let body = if !response.body.is_empty() {
439        Some(Body::Text(response.body))
440    } else {
441        None
442    };
443
444    ApiGatewayProxyResponse {
445        status_code: response.status_code as i64,
446        headers,
447        multi_value_headers: Default::default(),
448        body,
449        is_base64_encoded: true,
450    }
451}
452
453/// Update pending request with response data
454pub async fn update_pending_request_with_response(
455    client: &DynamoDbClient,
456    response: &HttpResponse,
457) -> Result<()> {
458    let table_name = std::env::var("PENDING_REQUESTS_TABLE_NAME")
459        .context("PENDING_REQUESTS_TABLE_NAME environment variable not set")?;
460
461    // Serialize response to JSON
462    let response_data =
463        serde_json::to_string(response).context("Failed to serialize response to JSON")?;
464
465    // Update pending request with response data
466    client
467        .update_item()
468        .table_name(&table_name)
469        .key("requestId", AttributeValue::S(response.request_id.clone()))
470        .update_expression("SET #status = :status, responseData = :data")
471        .expression_attribute_names("#status", "status")
472        .expression_attribute_values(":status", AttributeValue::S("completed".to_string()))
473        .expression_attribute_values(":data", AttributeValue::S(response_data))
474        .send()
475        .await
476        .context("Failed to update pending request with response")?;
477
478    debug!("Updated pending request: {}", response.request_id);
479
480    Ok(())
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use aws_lambda_events::encodings::Body;
    use http::Method;
    use std::collections::HashMap;

    #[test]
    fn test_build_http_request_simple_get() {
        let request = ApiGatewayProxyRequest {
            http_method: Method::GET,
            path: Some("/api/users".to_string()),
            ..Default::default()
        };

        let http_request = build_http_request(&request, "req_123".to_string());

        assert_eq!(http_request.request_id, "req_123");
        assert_eq!(http_request.method, "GET");
        assert_eq!(http_request.uri, "/api/users");
        assert!(http_request.body.is_empty());
    }

    /// A missing path must default to `/` (this replaces an exact duplicate of the test above).
    #[test]
    fn test_build_http_request_without_path_defaults_to_root() {
        let request = ApiGatewayProxyRequest {
            http_method: Method::GET,
            path: None,
            ..Default::default()
        };

        let http_request = build_http_request(&request, "req_123".to_string());

        assert_eq!(http_request.request_id, "req_123");
        assert_eq!(http_request.method, "GET");
        assert_eq!(http_request.uri, "/");
    }

    #[test]
    fn test_build_http_request_with_body() {
        let request = ApiGatewayProxyRequest {
            http_method: Method::POST,
            path: Some("/api/data".to_string()),
            body: Some("Hello World".to_string()),
            is_base64_encoded: false,
            ..Default::default()
        };

        let http_request = build_http_request(&request, "req_123".to_string());

        assert_eq!(http_request.method, "POST");
        assert!(!http_request.body.is_empty());
    }

    /// A body API Gateway already base64-encoded must be forwarded unchanged.
    #[test]
    fn test_build_http_request_base64_body_passthrough() {
        let request = ApiGatewayProxyRequest {
            http_method: Method::POST,
            path: Some("/api/data".to_string()),
            body: Some("SGVsbG8=".to_string()),
            is_base64_encoded: true,
            ..Default::default()
        };

        let http_request = build_http_request(&request, "req_123".to_string());

        assert_eq!(http_request.body, "SGVsbG8=");
    }

    #[test]
    fn test_strip_tunnel_id_from_path() {
        assert_eq!(strip_tunnel_id_from_path("/abc123/api/users"), "/api/users");
        assert_eq!(strip_tunnel_id_from_path("/abc123"), "/");
        assert_eq!(strip_tunnel_id_from_path("/abc123/"), "/");
    }

    #[test]
    fn test_extract_tunnel_id_from_path_missing_id() {
        assert!(extract_tunnel_id_from_path("/").is_err());
        assert!(extract_tunnel_id_from_path("").is_err());
    }

    #[test]
    fn test_build_api_gateway_response_success() {
        let mut headers = HashMap::new();
        headers.insert(
            "content-type".to_string(),
            vec!["application/json".to_string()],
        );

        let response = HttpResponse {
            request_id: "req_123".to_string(),
            status_code: 200,
            headers,
            body: "eyJ0ZXN0IjoidmFsdWUifQ==".to_string(),
            processing_time_ms: 123,
        };

        let apigw_response = build_api_gateway_response(response);

        assert_eq!(apigw_response.status_code, 200);
        assert!(apigw_response.is_base64_encoded);
        assert!(matches!(
            apigw_response.body,
            Some(Body::Text(ref text)) if text == "eyJ0ZXN0IjoidmFsdWUifQ=="
        ));
        assert_eq!(
            apigw_response
                .headers
                .get("content-type")
                .and_then(|value| value.to_str().ok()),
            Some("application/json")
        );
    }

    #[test]
    fn test_build_api_gateway_response_empty_body() {
        let response = HttpResponse {
            request_id: "req_123".to_string(),
            status_code: 204,
            headers: HashMap::new(),
            body: String::new(),
            processing_time_ms: 0,
        };

        let apigw_response = build_api_gateway_response(response);

        assert_eq!(apigw_response.status_code, 204);
        assert!(apigw_response.body.is_none());
    }
}