http_tunnel_handler/handlers/response.rs

//! ResponseHandler - Handles WebSocket $default route
//!
//! This module processes messages from the agent, including Ready handshakes, HTTP responses,
//! error messages, and ping/pong heartbeats. It updates the pending request status
//! in DynamoDB so the ForwardingHandler can complete the HTTP request.
6
7use aws_lambda_events::apigw::ApiGatewayProxyResponse;
8use aws_sdk_dynamodb::Client as DynamoDbClient;
9use aws_sdk_dynamodb::types::AttributeValue;
10use http_tunnel_common::encode_body;
11use http_tunnel_common::protocol::{ErrorCode, HttpResponse, Message};
12use lambda_runtime::{Error, LambdaEvent};
13use serde::{Deserialize, Serialize};
14use tracing::{debug, error, info, warn};
15
16use crate::{SharedClients, update_pending_request_with_response};
17use aws_sdk_apigatewaymanagement::primitives::Blob;
18
19/// WebSocket $default event structure (messages from agent)
20/// This is different from $connect/$disconnect events - it doesn't have connectedAt
21#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
22#[serde(rename_all = "camelCase")]
23pub struct WebSocketMessageEvent {
24    pub request_context: WebSocketMessageRequestContext,
25    pub body: Option<String>,
26    pub is_base64_encoded: Option<bool>,
27}
28
29#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
30#[serde(rename_all = "camelCase")]
31pub struct WebSocketMessageRequestContext {
32    pub route_key: String,
33    #[serde(default)]
34    pub event_type: Option<String>,
35    pub connection_id: String,
36    pub request_id: String,
37    pub domain_name: Option<String>,
38    pub stage: Option<String>,
39    pub api_id: Option<String>,
40    #[serde(default)]
41    pub connected_at: Option<i64>,
42}
43
44/// Handler for WebSocket $default route (messages from agent)
45pub async fn handle_response(
46    event: LambdaEvent<WebSocketMessageEvent>,
47    clients: &SharedClients,
48) -> Result<ApiGatewayProxyResponse, Error> {
49    let body = event.payload.body.ok_or("Missing message body")?;
50
51    debug!("Received message from agent: {}", body);
52
53    // Parse message
54    let message: Message = serde_json::from_str(&body).map_err(|e| {
55        error!("Failed to parse message: {}", e);
56        format!("Invalid message format: {}", e)
57    })?;
58
59    let connection_id = &event.payload.request_context.connection_id;
60
61    match message {
62        Message::Ready => {
63            info!("Received Ready message from agent, sending ConnectionEstablished");
64            handle_ready_message(&clients.dynamodb, &clients.apigw_management, connection_id)
65                .await?;
66        }
67        Message::HttpResponse(response) => {
68            info!(
69                "Received HTTP response for request {}: status {}",
70                response.request_id, response.status_code
71            );
72            handle_http_response(&clients.dynamodb, response).await?;
73        }
74        Message::Ping => {
75            // Heartbeat received, no action needed
76            debug!("Received ping from agent");
77        }
78        Message::Pong => {
79            // Pong received, no action needed
80            debug!("Received pong from agent");
81        }
82        Message::Error {
83            request_id,
84            code,
85            message: error_message,
86        } => {
87            if let Some(req_id) = request_id {
88                warn!(
89                    "Received error for request {}: {:?} - {}",
90                    req_id, code, error_message
91                );
92                handle_error_response(&clients.dynamodb, &req_id, code, &error_message).await?;
93            } else {
94                warn!("Received error without request ID: {}", error_message);
95            }
96        }
97        _ => {
98            warn!("Received unexpected message type");
99        }
100    }
101
102    // Always return success
103    Ok(ApiGatewayProxyResponse {
104        status_code: 200,
105        headers: Default::default(),
106        multi_value_headers: Default::default(),
107        body: None,
108        is_base64_encoded: false,
109    })
110}
111
112/// Handle HTTP response from agent
113async fn handle_http_response(
114    client: &DynamoDbClient,
115    response: HttpResponse,
116) -> Result<(), Error> {
117    update_pending_request_with_response(client, &response)
118        .await
119        .map_err(|e| {
120            error!(
121                "Failed to update pending request {}: {}",
122                response.request_id, e
123            );
124            format!("Failed to update pending request: {}", e)
125        })?;
126
127    debug!(
128        "Successfully updated pending request: {}",
129        response.request_id
130    );
131
132    Ok(())
133}
134
135/// Handle Ready message from agent - send back ConnectionEstablished with public URL
136async fn handle_ready_message(
137    dynamodb_client: &DynamoDbClient,
138    apigw_management: &Option<aws_sdk_apigatewaymanagement::Client>,
139    connection_id: &str,
140) -> Result<(), Error> {
141    // Look up connection metadata from DynamoDB
142    let table_name = std::env::var("CONNECTIONS_TABLE_NAME")
143        .map_err(|_| "CONNECTIONS_TABLE_NAME environment variable not set")?;
144
145    let result = dynamodb_client
146        .get_item()
147        .table_name(&table_name)
148        .key("connectionId", AttributeValue::S(connection_id.to_string()))
149        .send()
150        .await
151        .map_err(|e| {
152            error!(
153                "Failed to get connection metadata for {}: {}",
154                connection_id, e
155            );
156            format!("Failed to get connection metadata: {}", e)
157        })?;
158
159    let item = result.item.ok_or("Connection not found")?;
160
161    let tunnel_id = item
162        .get("tunnelId")
163        .and_then(|v| v.as_s().ok())
164        .ok_or("Missing tunnelId")?;
165
166    let public_url = item
167        .get("publicUrl")
168        .and_then(|v| v.as_s().ok())
169        .ok_or("Missing publicUrl")?;
170
171    // Send ConnectionEstablished message
172    if let Some(client) = apigw_management {
173        let message = Message::ConnectionEstablished {
174            connection_id: connection_id.to_string(),
175            tunnel_id: tunnel_id.clone(),
176            public_url: public_url.clone(),
177        };
178
179        let message_json = serde_json::to_string(&message)
180            .map_err(|e| format!("Failed to serialize ConnectionEstablished: {}", e))?;
181
182        info!(
183            "Sending ConnectionEstablished to {}: {}",
184            connection_id, message_json
185        );
186
187        // Retry logic with exponential backoff for WebSocket dispatch failures
188        // API Gateway WebSocket connections may not be immediately ready to receive messages
189        let mut retry_count = 0;
190        let max_retries = 3;
191        let mut delay_ms = 100;
192
193        loop {
194            match client
195                .post_to_connection()
196                .connection_id(connection_id)
197                .data(Blob::new(message_json.as_bytes()))
198                .send()
199                .await
200            {
201                Ok(_) => {
202                    info!(
203                        "✅ Sent ConnectionEstablished to {} (attempt {})",
204                        connection_id,
205                        retry_count + 1
206                    );
207                    break;
208                }
209                Err(e) => {
210                    retry_count += 1;
211                    if retry_count >= max_retries {
212                        error!(
213                            "Failed to send ConnectionEstablished to {} after {} attempts: {}",
214                            connection_id, max_retries, e
215                        );
216                        // Don't fail the request - connection is established, client will timeout and retry
217                        break;
218                    }
219                    warn!(
220                        "Failed to send ConnectionEstablished (attempt {}), retrying in {}ms: {}",
221                        retry_count, delay_ms, e
222                    );
223                    tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
224                    delay_ms *= 2; // Exponential backoff
225                }
226            }
227        }
228    } else {
229        error!("API Gateway Management client not available");
230    }
231
232    Ok(())
233}
234
235/// Handle error response from agent
236async fn handle_error_response(
237    client: &DynamoDbClient,
238    request_id: &str,
239    code: ErrorCode,
240    message: &str,
241) -> Result<(), Error> {
242    let table_name = std::env::var("PENDING_REQUESTS_TABLE_NAME")
243        .map_err(|_| "PENDING_REQUESTS_TABLE_NAME environment variable not set")?;
244
245    // Create error response with appropriate status code
246    let status_code = match code {
247        ErrorCode::InvalidRequest => 400,
248        ErrorCode::Timeout => 504,
249        ErrorCode::LocalServiceUnavailable => 503,
250        ErrorCode::InternalError => 502,
251    };
252
253    let error_response = HttpResponse {
254        request_id: request_id.to_string(),
255        status_code,
256        headers: [("Content-Type".to_string(), vec!["text/plain".to_string()])]
257            .into_iter()
258            .collect(),
259        body: encode_body(message.as_bytes()),
260        processing_time_ms: 0,
261    };
262
263    let response_data = serde_json::to_string(&error_response).map_err(|e| {
264        error!("Failed to serialize error response: {}", e);
265        format!("Failed to serialize error response: {}", e)
266    })?;
267
268    client
269        .update_item()
270        .table_name(&table_name)
271        .key("requestId", AttributeValue::S(request_id.to_string()))
272        .update_expression("SET #status = :status, responseData = :data")
273        .expression_attribute_names("#status", "status")
274        .expression_attribute_values(":status", AttributeValue::S("completed".to_string()))
275        .expression_attribute_values(":data", AttributeValue::S(response_data))
276        .send()
277        .await
278        .map_err(|e| {
279            error!(
280                "Failed to update pending request {} with error: {}",
281                request_id, e
282            );
283            format!("Failed to update pending request: {}", e)
284        })?;
285
286    debug!("Updated pending request with error: {}", request_id);
287
288    Ok(())
289}
290
291#[cfg(test)]
292mod tests {
293    use super::*;
294
295    #[test]
296    fn test_error_code_to_status_code() {
297        let codes = vec![
298            (ErrorCode::InvalidRequest, 400),
299            (ErrorCode::Timeout, 504),
300            (ErrorCode::LocalServiceUnavailable, 503),
301            (ErrorCode::InternalError, 502),
302        ];
303
304        for (error_code, expected_status) in codes {
305            let status = match error_code {
306                ErrorCode::InvalidRequest => 400,
307                ErrorCode::Timeout => 504,
308                ErrorCode::LocalServiceUnavailable => 503,
309                ErrorCode::InternalError => 502,
310            };
311            assert_eq!(status, expected_status);
312        }
313    }
314
315    #[test]
316    fn test_error_response_format() {
317        let error_response = HttpResponse {
318            request_id: "req_123".to_string(),
319            status_code: 502,
320            headers: [("Content-Type".to_string(), vec!["text/plain".to_string()])]
321                .into_iter()
322                .collect(),
323            body: encode_body(b"Service error"),
324            processing_time_ms: 0,
325        };
326
327        assert_eq!(error_response.status_code, 502);
328        assert_eq!(
329            error_response.headers.get("Content-Type").unwrap()[0],
330            "text/plain"
331        );
332        assert!(!error_response.body.is_empty());
333    }
334}