1use anyhow::{Context, Result, anyhow};
7use aws_lambda_events::apigw::{ApiGatewayProxyRequest, ApiGatewayProxyResponse};
8use aws_sdk_apigatewaymanagement::Client as ApiGatewayManagementClient;
9use aws_sdk_apigatewaymanagement::primitives::Blob;
10use aws_sdk_dynamodb::Client as DynamoDbClient;
11use aws_sdk_dynamodb::types::AttributeValue;
12use aws_sdk_eventbridge::Client as EventBridgeClient;
13use http_tunnel_common::ConnectionMetadata;
14use http_tunnel_common::constants::{
15 PENDING_REQUEST_TTL_SECS, POLL_BACKOFF_MULTIPLIER, POLL_INITIAL_INTERVAL_MS,
16 POLL_MAX_INTERVAL_MS, REQUEST_TIMEOUT_SECS,
17};
18use http_tunnel_common::protocol::{HttpRequest, HttpResponse};
19use http_tunnel_common::utils::{calculate_ttl, current_timestamp_millis, current_timestamp_secs};
20use std::time::{Duration, Instant};
21use tracing::{debug, error};
22
23pub mod auth;
24pub mod content_rewrite;
25pub mod error_handling;
26pub mod handlers;
27
/// Reports whether the event-driven response wait is enabled.
///
/// Reads the `USE_EVENT_DRIVEN` environment variable and returns `true` only when its
/// lowercased value is exactly `"true"` (so `"TRUE"` or `"True"` count, while `"1"`, `"yes"`
/// or `" true"` do not). An unset or non-Unicode variable yields `false`.
pub fn is_event_driven_enabled() -> bool {
    match std::env::var("USE_EVENT_DRIVEN") {
        Ok(raw) => raw.to_lowercase() == "true",
        Err(_) => false,
    }
}
35
/// Bundle of the AWS SDK clients this service uses: DynamoDB, an optional API Gateway
/// Management API client, and EventBridge.
///
/// NOTE(review): presumably built once per Lambda cold start and handed to the handlers;
/// confirm against the construction site.
pub struct SharedClients {
    /// DynamoDB client (this module's table helpers take a `&DynamoDbClient`).
    pub dynamodb: DynamoDbClient,
    /// API Gateway Management API client used to post to WebSocket connections; `None`
    /// when not configured (when exactly that happens is decided by the caller — TODO confirm).
    pub apigw_management: Option<ApiGatewayManagementClient>,
    /// EventBridge client.
    pub eventbridge: EventBridgeClient,
}
42
43pub fn extract_tunnel_id_from_path(path: &str) -> Result<String> {
46 let parts: Vec<&str> = path.trim_start_matches('/').split('/').collect();
47 if parts.is_empty() || parts[0].is_empty() {
48 return Err(anyhow!("Missing tunnel ID in path"));
49 }
50 let tunnel_id = parts[0].to_string();
51
52 http_tunnel_common::validation::validate_tunnel_id(&tunnel_id)
54 .context("Invalid tunnel ID format")?;
55
56 Ok(tunnel_id)
57}
58
/// Removes the leading tunnel-ID segment from `path` and returns the rest as an absolute path.
///
/// Leading slashes are skipped before the first segment is dropped. `"/abc123/api/users"`
/// becomes `"/api/users"`. A path consisting only of the tunnel ID (with or without a
/// trailing slash), or an empty path, becomes `"/"`. The remainder is otherwise kept
/// verbatim, including any query string or repeated slashes.
pub fn strip_tunnel_id_from_path(path: &str) -> String {
    match path.trim_start_matches('/').split_once('/') {
        Some((_, rest)) if !rest.is_empty() => format!("/{rest}"),
        _ => String::from("/"),
    }
}
70
71pub async fn save_connection_metadata(
73 client: &DynamoDbClient,
74 metadata: &ConnectionMetadata,
75) -> Result<()> {
76 let table_name = std::env::var("CONNECTIONS_TABLE_NAME")
77 .context("CONNECTIONS_TABLE_NAME environment variable not set")?;
78
79 client
80 .put_item()
81 .table_name(&table_name)
82 .item(
83 "connectionId",
84 AttributeValue::S(metadata.connection_id.clone()),
85 )
86 .item("tunnelId", AttributeValue::S(metadata.tunnel_id.clone()))
87 .item("publicUrl", AttributeValue::S(metadata.public_url.clone()))
88 .item(
89 "createdAt",
90 AttributeValue::N(metadata.created_at.to_string()),
91 )
92 .item("ttl", AttributeValue::N(metadata.ttl.to_string()))
93 .send()
94 .await
95 .context("Failed to save connection metadata to DynamoDB")?;
96
97 Ok(())
98}
99
100pub async fn delete_connection(client: &DynamoDbClient, connection_id: &str) -> Result<()> {
102 let table_name = std::env::var("CONNECTIONS_TABLE_NAME")
103 .context("CONNECTIONS_TABLE_NAME environment variable not set")?;
104
105 client
106 .delete_item()
107 .table_name(&table_name)
108 .key("connectionId", AttributeValue::S(connection_id.to_string()))
109 .send()
110 .await
111 .context("Failed to delete connection from DynamoDB")?;
112
113 Ok(())
114}
115
116pub async fn lookup_connection_by_tunnel_id(
118 client: &DynamoDbClient,
119 tunnel_id: &str,
120) -> Result<String> {
121 let table_name = std::env::var("CONNECTIONS_TABLE_NAME")
122 .context("CONNECTIONS_TABLE_NAME environment variable not set")?;
123 let index_name = "tunnel-id-index";
124
125 let result = client
126 .query()
127 .table_name(&table_name)
128 .index_name(index_name)
129 .key_condition_expression("tunnelId = :tunnel_id")
130 .expression_attribute_values(":tunnel_id", AttributeValue::S(tunnel_id.to_string()))
131 .limit(1)
132 .send()
133 .await
134 .context("Failed to query connection by tunnel ID")?;
135
136 let items = result.items.ok_or_else(|| anyhow!("No items returned"))?;
137 let item = items
138 .first()
139 .ok_or_else(|| anyhow!("Connection not found for tunnel ID: {}", tunnel_id))?;
140
141 let connection_id = item
142 .get("connectionId")
143 .and_then(|v| v.as_s().ok())
144 .ok_or_else(|| anyhow!("Missing connectionId in DynamoDB item"))?;
145
146 Ok(connection_id.clone())
147}
148
149pub fn build_http_request(request: &ApiGatewayProxyRequest, request_id: String) -> HttpRequest {
151 let method = request.http_method.to_string();
152
153 let uri = format!("{}{}", request.path.as_deref().unwrap_or("/"), {
154 let params = &request.query_string_parameters;
155 if params.is_empty() {
156 String::new()
157 } else {
158 format!(
159 "?{}",
160 params
161 .iter()
162 .map(|(k, v)| format!("{}={}", k, v))
163 .collect::<Vec<_>>()
164 .join("&")
165 )
166 }
167 });
168
169 let headers = request
170 .headers
171 .iter()
172 .map(|(k, v)| {
173 (
174 k.as_str().to_string(),
175 vec![v.to_str().unwrap_or("").to_string()],
176 )
177 })
178 .collect();
179
180 let body = request
181 .body
182 .as_ref()
183 .map(|b| {
184 if request.is_base64_encoded {
185 b.to_string() } else {
187 http_tunnel_common::encode_body(b.as_bytes())
188 }
189 })
190 .unwrap_or_default();
191
192 HttpRequest {
193 request_id,
194 method,
195 uri,
196 headers,
197 body,
198 timestamp: current_timestamp_millis(),
199 }
200}
201
202pub async fn save_pending_request(
204 client: &DynamoDbClient,
205 request_id: &str,
206 connection_id: &str,
207 api_gateway_request_id: &str,
208) -> Result<()> {
209 let table_name = std::env::var("PENDING_REQUESTS_TABLE_NAME")
210 .context("PENDING_REQUESTS_TABLE_NAME environment variable not set")?;
211 let created_at = current_timestamp_secs();
212 let ttl = calculate_ttl(PENDING_REQUEST_TTL_SECS);
213
214 client
215 .put_item()
216 .table_name(&table_name)
217 .item("requestId", AttributeValue::S(request_id.to_string()))
218 .item("connectionId", AttributeValue::S(connection_id.to_string()))
219 .item(
220 "apiGatewayRequestId",
221 AttributeValue::S(api_gateway_request_id.to_string()),
222 )
223 .item("createdAt", AttributeValue::N(created_at.to_string()))
224 .item("ttl", AttributeValue::N(ttl.to_string()))
225 .item("status", AttributeValue::S("pending".to_string()))
226 .send()
227 .await
228 .context("Failed to save pending request to DynamoDB")?;
229
230 Ok(())
231}
232
233pub async fn send_to_connection(
235 client: &ApiGatewayManagementClient,
236 connection_id: &str,
237 data: &str,
238) -> Result<()> {
239 client
240 .post_to_connection()
241 .connection_id(connection_id)
242 .data(Blob::new(data.as_bytes()))
243 .send()
244 .await
245 .context("Failed to send message to WebSocket connection")?;
246
247 Ok(())
248}
249
250pub async fn wait_for_response(client: &DynamoDbClient, request_id: &str) -> Result<HttpResponse> {
252 if is_event_driven_enabled() {
253 wait_for_response_event_driven(client, request_id).await
254 } else {
255 wait_for_response_polling(client, request_id).await
256 }
257}
258
259async fn check_for_response(
261 client: &DynamoDbClient,
262 table_name: &str,
263 request_id: &str,
264) -> Result<Option<HttpResponse>> {
265 let result = client
266 .get_item()
267 .table_name(table_name)
268 .key("requestId", AttributeValue::S(request_id.to_string()))
269 .send()
270 .await
271 .context("Failed to get pending request from DynamoDB")?;
272
273 if let Some(item) = result.item {
274 let status = item
275 .get("status")
276 .and_then(|v| v.as_s().ok())
277 .ok_or_else(|| anyhow!("Missing status in DynamoDB item"))?;
278
279 if status == "completed" {
280 let response_data = item
282 .get("responseData")
283 .and_then(|v| v.as_s().ok())
284 .ok_or_else(|| anyhow!("Missing responseData in completed request"))?;
285
286 let response: HttpResponse = serde_json::from_str(response_data)
287 .context("Failed to parse response data JSON")?;
288
289 if let Err(e) = client
291 .delete_item()
292 .table_name(table_name)
293 .key("requestId", AttributeValue::S(request_id.to_string()))
294 .send()
295 .await
296 {
297 error!("Failed to clean up pending request: {}", e);
298 }
299
300 return Ok(Some(response));
301 }
302 }
303
304 Ok(None)
305}
306
307async fn wait_for_response_event_driven(
310 client: &DynamoDbClient,
311 request_id: &str,
312) -> Result<HttpResponse> {
313 let table_name = std::env::var("PENDING_REQUESTS_TABLE_NAME")
314 .context("PENDING_REQUESTS_TABLE_NAME environment variable not set")?;
315 let timeout = Duration::from_secs(REQUEST_TIMEOUT_SECS);
316 let start = Instant::now();
317
318 if let Some(response) = check_for_response(client, &table_name, request_id).await? {
324 return Ok(response);
325 }
326
327 let wait_duration = Duration::from_millis(800);
330 tokio::time::sleep(wait_duration).await;
331
332 if let Some(response) = check_for_response(client, &table_name, request_id).await? {
334 return Ok(response);
335 }
336
337 let mut poll_interval = Duration::from_millis(200);
339 loop {
340 if start.elapsed() > timeout {
341 return Err(anyhow!("Request timeout waiting for response"));
342 }
343
344 tokio::time::sleep(poll_interval).await;
345
346 if let Some(response) = check_for_response(client, &table_name, request_id).await? {
347 return Ok(response);
348 }
349
350 poll_interval = Duration::from_millis(500); }
352}
353
354async fn wait_for_response_polling(
356 client: &DynamoDbClient,
357 request_id: &str,
358) -> Result<HttpResponse> {
359 let table_name = std::env::var("PENDING_REQUESTS_TABLE_NAME")
360 .context("PENDING_REQUESTS_TABLE_NAME environment variable not set")?;
361 let timeout = Duration::from_secs(REQUEST_TIMEOUT_SECS);
362 let start = Instant::now();
363
364 let mut poll_interval = Duration::from_millis(POLL_INITIAL_INTERVAL_MS);
366 let max_poll_interval = Duration::from_millis(POLL_MAX_INTERVAL_MS);
367
368 loop {
369 if start.elapsed() > timeout {
370 return Err(anyhow!("Request timeout waiting for response"));
371 }
372
373 let result = client
375 .get_item()
376 .table_name(&table_name)
377 .key("requestId", AttributeValue::S(request_id.to_string()))
378 .send()
379 .await
380 .context("Failed to get pending request from DynamoDB")?;
381
382 if let Some(item) = result.item {
383 let status = item
384 .get("status")
385 .and_then(|v| v.as_s().ok())
386 .ok_or_else(|| anyhow!("Missing status in DynamoDB item"))?;
387
388 if status == "completed" {
389 let response_data = item
391 .get("responseData")
392 .and_then(|v| v.as_s().ok())
393 .ok_or_else(|| anyhow!("Missing responseData in completed request"))?;
394
395 let response: HttpResponse = serde_json::from_str(response_data)
396 .context("Failed to parse response data JSON")?;
397
398 if let Err(e) = client
400 .delete_item()
401 .table_name(&table_name)
402 .key("requestId", AttributeValue::S(request_id.to_string()))
403 .send()
404 .await
405 {
406 error!("Failed to clean up pending request: {}", e);
407 }
408
409 return Ok(response);
410 }
411 }
412
413 tokio::time::sleep(poll_interval).await;
414
415 poll_interval = std::cmp::min(poll_interval * POLL_BACKOFF_MULTIPLIER, max_poll_interval);
417 }
418}
419
420pub fn build_api_gateway_response(response: HttpResponse) -> ApiGatewayProxyResponse {
422 use http::header::{HeaderName, HeaderValue};
423
424 let headers = response
425 .headers
426 .iter()
427 .filter_map(|(k, v)| {
428 v.first().and_then(|val| {
429 HeaderName::from_bytes(k.as_bytes())
430 .ok()
431 .and_then(|name| HeaderValue::from_str(val).ok().map(|value| (name, value)))
432 })
433 })
434 .collect();
435
436 use aws_lambda_events::encodings::Body;
437
438 let body = if !response.body.is_empty() {
439 Some(Body::Text(response.body))
440 } else {
441 None
442 };
443
444 ApiGatewayProxyResponse {
445 status_code: response.status_code as i64,
446 headers,
447 multi_value_headers: Default::default(),
448 body,
449 is_base64_encoded: true,
450 }
451}
452
453pub async fn update_pending_request_with_response(
455 client: &DynamoDbClient,
456 response: &HttpResponse,
457) -> Result<()> {
458 let table_name = std::env::var("PENDING_REQUESTS_TABLE_NAME")
459 .context("PENDING_REQUESTS_TABLE_NAME environment variable not set")?;
460
461 let response_data =
463 serde_json::to_string(response).context("Failed to serialize response to JSON")?;
464
465 client
467 .update_item()
468 .table_name(&table_name)
469 .key("requestId", AttributeValue::S(response.request_id.clone()))
470 .update_expression("SET #status = :status, responseData = :data")
471 .expression_attribute_names("#status", "status")
472 .expression_attribute_values(":status", AttributeValue::S("completed".to_string()))
473 .expression_attribute_values(":data", AttributeValue::S(response_data))
474 .send()
475 .await
476 .context("Failed to update pending request with response")?;
477
478 debug!("Updated pending request: {}", response.request_id);
479
480 Ok(())
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;
    use std::collections::HashMap;

    /// Builds a proxy request with only the method and path set; everything else defaulted.
    fn proxy_request(method: Method, path: &str) -> ApiGatewayProxyRequest {
        ApiGatewayProxyRequest {
            http_method: method,
            path: Some(path.to_string()),
            ..Default::default()
        }
    }

    #[test]
    fn test_build_http_request_simple_get() {
        let request = proxy_request(Method::GET, "/api/users");

        let built = build_http_request(&request, "req_123".to_string());

        assert_eq!(built.request_id, "req_123");
        assert_eq!(built.method, "GET");
        assert_eq!(built.uri, "/api/users");
        assert!(built.body.is_empty());
    }

    #[test]
    fn test_build_http_request_with_path() {
        let request = proxy_request(Method::GET, "/api/users");

        let built = build_http_request(&request, "req_123".to_string());

        assert_eq!(built.request_id, "req_123");
        assert_eq!(built.method, "GET");
        assert_eq!(built.uri, "/api/users");
    }

    #[test]
    fn test_build_http_request_with_body() {
        let request = ApiGatewayProxyRequest {
            body: Some("Hello World".to_string()),
            is_base64_encoded: false,
            ..proxy_request(Method::POST, "/api/data")
        };

        let built = build_http_request(&request, "req_123".to_string());

        assert_eq!(built.method, "POST");
        assert!(!built.body.is_empty());
    }

    #[test]
    fn test_build_api_gateway_response_success() {
        let headers = HashMap::from([(
            "content-type".to_string(),
            vec!["application/json".to_string()],
        )]);
        let response = HttpResponse {
            request_id: "req_123".to_string(),
            status_code: 200,
            headers,
            body: "eyJ0ZXN0IjoidmFsdWUifQ==".to_string(),
            processing_time_ms: 123,
        };

        let converted = build_api_gateway_response(response);

        assert_eq!(converted.status_code, 200);
        assert!(converted.is_base64_encoded);
        assert!(converted.body.is_some());
        assert!(!converted.headers.is_empty());
    }

    #[test]
    fn test_build_api_gateway_response_empty_body() {
        let response = HttpResponse {
            request_id: "req_123".to_string(),
            status_code: 204,
            headers: HashMap::new(),
            body: String::new(),
            processing_time_ms: 0,
        };

        let converted = build_api_gateway_response(response);

        assert_eq!(converted.status_code, 204);
        assert!(converted.body.is_none());
    }
}