Skip to main content

codive_relay/routes/
proxy.rs

1//! HTTP proxy handler that routes requests to tunnels
2
3use axum::{
4    body::Body,
5    extract::State,
6    http::{Request, Response, StatusCode},
7    response::IntoResponse,
8};
9use base64::Engine;
10use bytes::Bytes;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio_stream::wrappers::ReceiverStream;
14use tracing::{debug, info, warn};
15
16use crate::state::RelayState;
17use crate::tunnel::WsMessage;
18
19/// Check if this is an SSE request
20fn is_sse_request(request: &Request<Body>) -> bool {
21    // Check Accept header
22    if let Some(accept) = request.headers().get("accept") {
23        if let Ok(accept_str) = accept.to_str() {
24            if accept_str.contains("text/event-stream") {
25                return true;
26            }
27        }
28    }
29
30    // Check path pattern (common SSE endpoint patterns)
31    let path = request.uri().path();
32    path.ends_with("/events") || path.contains("/events/") || path.ends_with("/stream")
33}
34
35/// Proxy handler that routes HTTP requests to the appropriate tunnel
36pub async fn proxy_handler(
37    State(state): State<Arc<RelayState>>,
38    request: Request<Body>,
39) -> impl IntoResponse {
40    // Extract host from request headers
41    let host = request
42        .headers()
43        .get(axum::http::header::HOST)
44        .and_then(|v| v.to_str().ok())
45        .unwrap_or("");
46
47    // Extract tunnel_id from subdomain
48    // Host format: {tunnel_id}.relay.example.com
49    let tunnel_id = match extract_tunnel_id(host, &state.config.base_domain) {
50        Some(id) => id,
51        None => {
52            return (
53                StatusCode::NOT_FOUND,
54                format!("No tunnel found for host: {}", host),
55            )
56                .into_response();
57        }
58    };
59
60    // Look up the tunnel
61    let tunnel = match state.get_tunnel(&tunnel_id) {
62        Some(t) => t,
63        None => {
64            return (
65                StatusCode::BAD_GATEWAY,
66                format!("Tunnel '{}' is not connected", tunnel_id),
67            )
68                .into_response();
69        }
70    };
71
72    // Check if this is an SSE request
73    let is_sse = is_sse_request(&request);
74
75    info!(
76        tunnel_id = %tunnel_id,
77        method = %request.method(),
78        uri = %request.uri(),
79        is_sse = %is_sse,
80        "Proxying request"
81    );
82
83    // Generate request ID
84    let request_id = uuid::Uuid::new_v4().to_string();
85
86    // Convert HTTP request to tunnel message
87    let method = request.method().to_string();
88    let path = request.uri().path().to_string();
89    let query = request.uri().query().map(|s| s.to_string());
90
91    let mut headers = std::collections::HashMap::new();
92    for (name, value) in request.headers() {
93        if let Ok(v) = value.to_str() {
94            headers.insert(name.to_string(), v.to_string());
95        }
96    }
97
98    // Read body
99    let body_bytes = match axum::body::to_bytes(request.into_body(), 10 * 1024 * 1024).await {
100        Ok(bytes) => bytes,
101        Err(e) => {
102            warn!(error = %e, "Failed to read request body");
103            return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response();
104        }
105    };
106
107    let body = if body_bytes.is_empty() {
108        None
109    } else {
110        Some(base64::engine::general_purpose::STANDARD.encode(&body_bytes))
111    };
112
113    // Create the data message
114    let data_msg = codive_tunnel::DataMessage::HttpRequest {
115        request_id: request_id.clone(),
116        client_id: "proxy".to_string(),
117        method,
118        path,
119        query,
120        headers,
121        body,
122    };
123
124    // Serialize the request
125    let msg_bytes = serde_json::to_vec(&data_msg).unwrap();
126    let wire_msg = codive_tunnel::WireMessage::encode_encrypted(
127        codive_tunnel::message_type::ENCRYPTED_REQUEST,
128        msg_bytes,
129    );
130
131    // IMPORTANT: Register pending request BEFORE sending to avoid race condition
132    // The tunnel client might respond before we're listening otherwise
133    enum ResponseReceiver {
134        Regular(tokio::sync::oneshot::Receiver<codive_tunnel::DataMessage>),
135        Streaming(tokio::sync::mpsc::Receiver<codive_tunnel::DataMessage>),
136    }
137
138    let response_rx = if is_sse {
139        ResponseReceiver::Streaming(tunnel.register_streaming_request(request_id.clone()))
140    } else {
141        ResponseReceiver::Regular(tunnel.register_request(request_id.clone()))
142    };
143
144    debug!(
145        request_id = %request_id,
146        tunnel_id = %tunnel_id,
147        is_sse = is_sse,
148        "Sending request to tunnel"
149    );
150
151    if tunnel.ws_sender.send(WsMessage::Binary(wire_msg)).await.is_err() {
152        tunnel.pending_requests.remove(&request_id);
153        return (StatusCode::BAD_GATEWAY, "Tunnel connection lost").into_response();
154    }
155
156    debug!(
157        request_id = %request_id,
158        "Request sent, waiting for response"
159    );
160
161    // Handle SSE vs regular requests differently
162    match response_rx {
163        ResponseReceiver::Streaming(rx) => {
164            handle_sse_response(tunnel, request_id, rx).await
165        }
166        ResponseReceiver::Regular(rx) => {
167            handle_regular_response(tunnel, request_id, rx).await
168        }
169    }
170}
171
172/// Handle regular (non-streaming) HTTP response
173async fn handle_regular_response(
174    tunnel: Arc<crate::tunnel::TunnelConnection>,
175    request_id: String,
176    response_rx: tokio::sync::oneshot::Receiver<codive_tunnel::DataMessage>,
177) -> Response<Body> {
178    // Wait for response with timeout
179    let timeout = Duration::from_secs(30);
180    match tokio::time::timeout(timeout, response_rx).await {
181        Ok(Ok(response_msg)) => {
182            // Convert tunnel response to HTTP response
183            build_http_response(response_msg)
184        }
185        Ok(Err(_)) => {
186            // Channel closed - tunnel disconnected
187            (StatusCode::BAD_GATEWAY, "Tunnel disconnected").into_response()
188        }
189        Err(_) => {
190            // Timeout
191            tunnel.pending_requests.remove(&request_id);
192            (StatusCode::GATEWAY_TIMEOUT, "Request timed out").into_response()
193        }
194    }
195}
196
197/// Handle SSE streaming response
198async fn handle_sse_response(
199    tunnel: Arc<crate::tunnel::TunnelConnection>,
200    request_id: String,
201    mut response_rx: tokio::sync::mpsc::Receiver<codive_tunnel::DataMessage>,
202) -> Response<Body> {
203    // Wait for the initial response with headers
204    let timeout = Duration::from_secs(30);
205    let initial_response = match tokio::time::timeout(timeout, response_rx.recv()).await {
206        Ok(Some(msg)) => msg,
207        Ok(None) => {
208            tunnel.complete_streaming_request(&request_id);
209            return (StatusCode::BAD_GATEWAY, "Tunnel disconnected").into_response();
210        }
211        Err(_) => {
212            tunnel.complete_streaming_request(&request_id);
213            return (StatusCode::GATEWAY_TIMEOUT, "Request timed out").into_response();
214        }
215    };
216
217    // Extract initial response headers
218    let (status, initial_headers, initial_body) = match initial_response {
219        codive_tunnel::DataMessage::HttpResponse {
220            status,
221            headers,
222            body,
223            streaming,
224            ..
225        } => {
226            if !streaming {
227                // Not actually a streaming response, handle as regular
228                tunnel.complete_streaming_request(&request_id);
229                let response_msg = codive_tunnel::DataMessage::HttpResponse {
230                    request_id: request_id.clone(),
231                    status,
232                    headers,
233                    body,
234                    streaming: false,
235                };
236                return build_http_response(response_msg);
237            }
238            (status, headers, body)
239        }
240        codive_tunnel::DataMessage::RequestError { message, .. } => {
241            tunnel.complete_streaming_request(&request_id);
242            return Response::builder()
243                .status(StatusCode::BAD_GATEWAY)
244                .body(Body::from(message))
245                .unwrap();
246        }
247        _ => {
248            tunnel.complete_streaming_request(&request_id);
249            return (StatusCode::INTERNAL_SERVER_ERROR, "Unexpected response type").into_response();
250        }
251    };
252
253    debug!(
254        request_id = %request_id,
255        status = %status,
256        "Starting SSE stream"
257    );
258
259    // Create a channel for the streaming body
260    let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Result<Bytes, std::io::Error>>(100);
261
262    // Send initial body chunk if present
263    if let Some(b64_body) = initial_body {
264        if let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(&b64_body) {
265            let _ = body_tx.send(Ok(Bytes::from(bytes))).await;
266        }
267    }
268
269    // Spawn task to forward chunks
270    let req_id = request_id.clone();
271    tokio::spawn(async move {
272        while let Some(msg) = response_rx.recv().await {
273            match msg {
274                codive_tunnel::DataMessage::HttpResponseChunk {
275                    chunk,
276                    is_final,
277                    ..
278                } => {
279                    // Decode and forward chunk
280                    if let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(&chunk) {
281                        if body_tx.send(Ok(Bytes::from(bytes))).await.is_err() {
282                            debug!(request_id = %req_id, "Client disconnected");
283                            break;
284                        }
285                    }
286
287                    if is_final {
288                        debug!(request_id = %req_id, "SSE stream completed");
289                        break;
290                    }
291                }
292                codive_tunnel::DataMessage::RequestError { message, .. } => {
293                    warn!(request_id = %req_id, error = %message, "SSE stream error");
294                    break;
295                }
296                _ => {
297                    // Ignore other message types in streaming context
298                }
299            }
300        }
301        // Channel will be dropped here, ending the stream
302    });
303
304    // Build response with streaming body
305    let status = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
306    let mut response = Response::builder().status(status);
307
308    for (name, value) in initial_headers {
309        response = response.header(name, value);
310    }
311
312    // Use ReceiverStream to convert mpsc receiver to stream
313    let body_stream = ReceiverStream::new(body_rx);
314    let body = Body::from_stream(body_stream);
315
316    response.body(body).unwrap_or_else(|_| {
317        Response::builder()
318            .status(StatusCode::INTERNAL_SERVER_ERROR)
319            .body(Body::from("Internal error"))
320            .unwrap()
321    })
322}
323
/// Extract the tunnel ID from a `Host` header value.
///
/// Anything after the first `:` (a port) is ignored in both `host` and `base_domain`.
/// A tunnel ID is returned when:
/// - `host` is `<id>.<base_domain>` — the `.` separator is required, so a host that merely
///   ends with the same characters (e.g. `evilrelay.example.com` for base
///   `relay.example.com`) is rejected; or
/// - `host` is `<id>.localhost` (development convenience).
///
/// Returns `None` when the ID would be empty, when `host` equals the base domain, or when
/// neither pattern matches. An empty `base_domain` never matches.
fn extract_tunnel_id(host: &str, base_domain: &str) -> Option<String> {
    // Remove port if present
    let host = host.split(':').next().unwrap_or(host);
    let base = base_domain.split(':').next().unwrap_or(base_domain);

    // `<id>.<base>`: strip the base domain, then require the label-separating dot.
    if !base.is_empty() {
        let subdomain = host
            .strip_suffix(base)
            .and_then(|prefix| prefix.strip_suffix('.'))
            .filter(|id| !id.is_empty());
        if let Some(id) = subdomain {
            return Some(id.to_owned());
        }
    }

    // For development: if host matches tunnel_id.localhost pattern
    host.strip_suffix(".localhost")
        .filter(|id| !id.is_empty())
        .map(str::to_owned)
}
346
347/// Build an HTTP response from a tunnel response message
348fn build_http_response(msg: codive_tunnel::DataMessage) -> Response<Body> {
349    match msg {
350        codive_tunnel::DataMessage::HttpResponse {
351            status,
352            headers,
353            body,
354            ..
355        } => {
356            let status = StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
357
358            let mut response = Response::builder().status(status);
359
360            for (name, value) in headers {
361                response = response.header(name, value);
362            }
363
364            let body = if let Some(b64_body) = body {
365                match base64::engine::general_purpose::STANDARD.decode(&b64_body) {
366                    Ok(bytes) => Body::from(bytes),
367                    Err(_) => Body::from(b64_body), // Fallback to raw if not base64
368                }
369            } else {
370                Body::empty()
371            };
372
373            response.body(body).unwrap_or_else(|_| {
374                Response::builder()
375                    .status(StatusCode::INTERNAL_SERVER_ERROR)
376                    .body(Body::from("Internal error"))
377                    .unwrap()
378            })
379        }
380        codive_tunnel::DataMessage::RequestError { message, .. } => Response::builder()
381            .status(StatusCode::BAD_GATEWAY)
382            .body(Body::from(message))
383            .unwrap(),
384        _ => Response::builder()
385            .status(StatusCode::INTERNAL_SERVER_ERROR)
386            .body(Body::from("Unexpected response type"))
387            .unwrap(),
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `extract_tunnel_id`: each row is
    /// `(host, base_domain, expected tunnel id)`.
    #[test]
    fn test_extract_tunnel_id() {
        let cases: [(&str, &str, Option<&str>); 4] = [
            // Plain subdomain of the base domain.
            ("abc123.relay.example.com", "relay.example.com", Some("abc123")),
            // Ports on both host and base domain are ignored.
            (
                "abc123.relay.example.com:3001",
                "relay.example.com:3001",
                Some("abc123"),
            ),
            // Development-style localhost host.
            ("abc123.localhost", "localhost", Some("abc123")),
            // The bare base domain carries no tunnel id.
            ("relay.example.com", "relay.example.com", None),
        ];

        for &(host, base_domain, expected) in &cases {
            assert_eq!(
                extract_tunnel_id(host, base_domain),
                expected.map(|id| id.to_string())
            );
        }
    }
}