http_tunnel_handler/handlers/
response.rs1use aws_lambda_events::apigw::ApiGatewayProxyResponse;
8use aws_sdk_dynamodb::Client as DynamoDbClient;
9use aws_sdk_dynamodb::types::AttributeValue;
10use http_tunnel_common::encode_body;
11use http_tunnel_common::protocol::{ErrorCode, HttpResponse, Message};
12use lambda_runtime::{Error, LambdaEvent};
13use serde::{Deserialize, Serialize};
14use tracing::{debug, error, info, warn};
15
16use crate::{SharedClients, update_pending_request_with_response};
17use aws_sdk_apigatewaymanagement::primitives::Blob;
18
19#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
22#[serde(rename_all = "camelCase")]
23pub struct WebSocketMessageEvent {
24 pub request_context: WebSocketMessageRequestContext,
25 pub body: Option<String>,
26 pub is_base64_encoded: Option<bool>,
27}
28
29#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
30#[serde(rename_all = "camelCase")]
31pub struct WebSocketMessageRequestContext {
32 pub route_key: String,
33 #[serde(default)]
34 pub event_type: Option<String>,
35 pub connection_id: String,
36 pub request_id: String,
37 pub domain_name: Option<String>,
38 pub stage: Option<String>,
39 pub api_id: Option<String>,
40 #[serde(default)]
41 pub connected_at: Option<i64>,
42}
43
44pub async fn handle_response(
46 event: LambdaEvent<WebSocketMessageEvent>,
47 clients: &SharedClients,
48) -> Result<ApiGatewayProxyResponse, Error> {
49 let body = event.payload.body.ok_or("Missing message body")?;
50
51 debug!("Received message from agent: {}", body);
52
53 let message: Message = serde_json::from_str(&body).map_err(|e| {
55 error!("Failed to parse message: {}", e);
56 format!("Invalid message format: {}", e)
57 })?;
58
59 let connection_id = &event.payload.request_context.connection_id;
60
61 match message {
62 Message::Ready => {
63 info!("Received Ready message from agent, sending ConnectionEstablished");
64 handle_ready_message(&clients.dynamodb, &clients.apigw_management, connection_id)
65 .await?;
66 }
67 Message::HttpResponse(response) => {
68 info!(
69 "Received HTTP response for request {}: status {}",
70 response.request_id, response.status_code
71 );
72 handle_http_response(&clients.dynamodb, response).await?;
73 }
74 Message::Ping => {
75 debug!("Received ping from agent");
77 }
78 Message::Pong => {
79 debug!("Received pong from agent");
81 }
82 Message::Error {
83 request_id,
84 code,
85 message: error_message,
86 } => {
87 if let Some(req_id) = request_id {
88 warn!(
89 "Received error for request {}: {:?} - {}",
90 req_id, code, error_message
91 );
92 handle_error_response(&clients.dynamodb, &req_id, code, &error_message).await?;
93 } else {
94 warn!("Received error without request ID: {}", error_message);
95 }
96 }
97 _ => {
98 warn!("Received unexpected message type");
99 }
100 }
101
102 Ok(ApiGatewayProxyResponse {
104 status_code: 200,
105 headers: Default::default(),
106 multi_value_headers: Default::default(),
107 body: None,
108 is_base64_encoded: false,
109 })
110}
111
112async fn handle_http_response(
114 client: &DynamoDbClient,
115 response: HttpResponse,
116) -> Result<(), Error> {
117 update_pending_request_with_response(client, &response)
118 .await
119 .map_err(|e| {
120 error!(
121 "Failed to update pending request {}: {}",
122 response.request_id, e
123 );
124 format!("Failed to update pending request: {}", e)
125 })?;
126
127 debug!(
128 "Successfully updated pending request: {}",
129 response.request_id
130 );
131
132 Ok(())
133}
134
135async fn handle_ready_message(
137 dynamodb_client: &DynamoDbClient,
138 apigw_management: &Option<aws_sdk_apigatewaymanagement::Client>,
139 connection_id: &str,
140) -> Result<(), Error> {
141 let table_name = std::env::var("CONNECTIONS_TABLE_NAME")
143 .map_err(|_| "CONNECTIONS_TABLE_NAME environment variable not set")?;
144
145 let result = dynamodb_client
146 .get_item()
147 .table_name(&table_name)
148 .key("connectionId", AttributeValue::S(connection_id.to_string()))
149 .send()
150 .await
151 .map_err(|e| {
152 error!(
153 "Failed to get connection metadata for {}: {}",
154 connection_id, e
155 );
156 format!("Failed to get connection metadata: {}", e)
157 })?;
158
159 let item = result.item.ok_or("Connection not found")?;
160
161 let tunnel_id = item
162 .get("tunnelId")
163 .and_then(|v| v.as_s().ok())
164 .ok_or("Missing tunnelId")?;
165
166 let public_url = item
167 .get("publicUrl")
168 .and_then(|v| v.as_s().ok())
169 .ok_or("Missing publicUrl")?;
170
171 if let Some(client) = apigw_management {
173 let message = Message::ConnectionEstablished {
174 connection_id: connection_id.to_string(),
175 tunnel_id: tunnel_id.clone(),
176 public_url: public_url.clone(),
177 };
178
179 let message_json = serde_json::to_string(&message)
180 .map_err(|e| format!("Failed to serialize ConnectionEstablished: {}", e))?;
181
182 info!(
183 "Sending ConnectionEstablished to {}: {}",
184 connection_id, message_json
185 );
186
187 let mut retry_count = 0;
190 let max_retries = 3;
191 let mut delay_ms = 100;
192
193 loop {
194 match client
195 .post_to_connection()
196 .connection_id(connection_id)
197 .data(Blob::new(message_json.as_bytes()))
198 .send()
199 .await
200 {
201 Ok(_) => {
202 info!(
203 "✅ Sent ConnectionEstablished to {} (attempt {})",
204 connection_id,
205 retry_count + 1
206 );
207 break;
208 }
209 Err(e) => {
210 retry_count += 1;
211 if retry_count >= max_retries {
212 error!(
213 "Failed to send ConnectionEstablished to {} after {} attempts: {}",
214 connection_id, max_retries, e
215 );
216 break;
218 }
219 warn!(
220 "Failed to send ConnectionEstablished (attempt {}), retrying in {}ms: {}",
221 retry_count, delay_ms, e
222 );
223 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
224 delay_ms *= 2; }
226 }
227 }
228 } else {
229 error!("API Gateway Management client not available");
230 }
231
232 Ok(())
233}
234
235async fn handle_error_response(
237 client: &DynamoDbClient,
238 request_id: &str,
239 code: ErrorCode,
240 message: &str,
241) -> Result<(), Error> {
242 let table_name = std::env::var("PENDING_REQUESTS_TABLE_NAME")
243 .map_err(|_| "PENDING_REQUESTS_TABLE_NAME environment variable not set")?;
244
245 let status_code = match code {
247 ErrorCode::InvalidRequest => 400,
248 ErrorCode::Timeout => 504,
249 ErrorCode::LocalServiceUnavailable => 503,
250 ErrorCode::InternalError => 502,
251 };
252
253 let error_response = HttpResponse {
254 request_id: request_id.to_string(),
255 status_code,
256 headers: [("Content-Type".to_string(), vec!["text/plain".to_string()])]
257 .into_iter()
258 .collect(),
259 body: encode_body(message.as_bytes()),
260 processing_time_ms: 0,
261 };
262
263 let response_data = serde_json::to_string(&error_response).map_err(|e| {
264 error!("Failed to serialize error response: {}", e);
265 format!("Failed to serialize error response: {}", e)
266 })?;
267
268 client
269 .update_item()
270 .table_name(&table_name)
271 .key("requestId", AttributeValue::S(request_id.to_string()))
272 .update_expression("SET #status = :status, responseData = :data")
273 .expression_attribute_names("#status", "status")
274 .expression_attribute_values(":status", AttributeValue::S("completed".to_string()))
275 .expression_attribute_values(":data", AttributeValue::S(response_data))
276 .send()
277 .await
278 .map_err(|e| {
279 error!(
280 "Failed to update pending request {} with error: {}",
281 request_id, e
282 );
283 format!("Failed to update pending request: {}", e)
284 })?;
285
286 debug!("Updated pending request with error: {}", request_id);
287
288 Ok(())
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): this table re-implements the `ErrorCode` -> status mapping
    // from `handle_error_response` instead of calling it, so it only guards
    // this copy. Keep both in sync by hand until the mapping is extracted into
    // a shared helper.
    #[test]
    fn test_error_code_to_status_code() {
        let cases = [
            (ErrorCode::InvalidRequest, 400),
            (ErrorCode::Timeout, 504),
            (ErrorCode::LocalServiceUnavailable, 503),
            (ErrorCode::InternalError, 502),
        ];

        for (code, want) in cases {
            let got = match code {
                ErrorCode::InvalidRequest => 400,
                ErrorCode::Timeout => 504,
                ErrorCode::LocalServiceUnavailable => 503,
                ErrorCode::InternalError => 502,
            };
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_error_response_format() {
        let body = encode_body(b"Service error");
        let response = HttpResponse {
            request_id: "req_123".to_string(),
            status_code: 502,
            headers: [("Content-Type".to_string(), vec!["text/plain".to_string()])]
                .into_iter()
                .collect(),
            body,
            processing_time_ms: 0,
        };

        assert_eq!(response.status_code, 502);

        let content_type = response
            .headers
            .get("Content-Type")
            .expect("Content-Type header should be present");
        assert_eq!(content_type[0], "text/plain");

        assert!(!response.body.is_empty());
    }
}