1use axum::{
4 body::Body,
5 extract::State,
6 http::{Request, Response, StatusCode},
7 response::IntoResponse,
8};
9use base64::Engine;
10use bytes::Bytes;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio_stream::wrappers::ReceiverStream;
14use tracing::{debug, info, warn};
15
16use crate::state::RelayState;
17use crate::tunnel::WsMessage;
18
19fn is_sse_request(request: &Request<Body>) -> bool {
21 if let Some(accept) = request.headers().get("accept") {
23 if let Ok(accept_str) = accept.to_str() {
24 if accept_str.contains("text/event-stream") {
25 return true;
26 }
27 }
28 }
29
30 let path = request.uri().path();
32 path.ends_with("/events") || path.contains("/events/") || path.ends_with("/stream")
33}
34
35pub async fn proxy_handler(
37 State(state): State<Arc<RelayState>>,
38 request: Request<Body>,
39) -> impl IntoResponse {
40 let host = request
42 .headers()
43 .get(axum::http::header::HOST)
44 .and_then(|v| v.to_str().ok())
45 .unwrap_or("");
46
47 let tunnel_id = match extract_tunnel_id(host, &state.config.base_domain) {
50 Some(id) => id,
51 None => {
52 return (
53 StatusCode::NOT_FOUND,
54 format!("No tunnel found for host: {}", host),
55 )
56 .into_response();
57 }
58 };
59
60 let tunnel = match state.get_tunnel(&tunnel_id) {
62 Some(t) => t,
63 None => {
64 return (
65 StatusCode::BAD_GATEWAY,
66 format!("Tunnel '{}' is not connected", tunnel_id),
67 )
68 .into_response();
69 }
70 };
71
72 let is_sse = is_sse_request(&request);
74
75 info!(
76 tunnel_id = %tunnel_id,
77 method = %request.method(),
78 uri = %request.uri(),
79 is_sse = %is_sse,
80 "Proxying request"
81 );
82
83 let request_id = uuid::Uuid::new_v4().to_string();
85
86 let method = request.method().to_string();
88 let path = request.uri().path().to_string();
89 let query = request.uri().query().map(|s| s.to_string());
90
91 let mut headers = std::collections::HashMap::new();
92 for (name, value) in request.headers() {
93 if let Ok(v) = value.to_str() {
94 headers.insert(name.to_string(), v.to_string());
95 }
96 }
97
98 let body_bytes = match axum::body::to_bytes(request.into_body(), 10 * 1024 * 1024).await {
100 Ok(bytes) => bytes,
101 Err(e) => {
102 warn!(error = %e, "Failed to read request body");
103 return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response();
104 }
105 };
106
107 let body = if body_bytes.is_empty() {
108 None
109 } else {
110 Some(base64::engine::general_purpose::STANDARD.encode(&body_bytes))
111 };
112
113 let data_msg = codive_tunnel::DataMessage::HttpRequest {
115 request_id: request_id.clone(),
116 client_id: "proxy".to_string(),
117 method,
118 path,
119 query,
120 headers,
121 body,
122 };
123
124 let msg_bytes = serde_json::to_vec(&data_msg).unwrap();
126 let wire_msg = codive_tunnel::WireMessage::encode_encrypted(
127 codive_tunnel::message_type::ENCRYPTED_REQUEST,
128 msg_bytes,
129 );
130
131 enum ResponseReceiver {
134 Regular(tokio::sync::oneshot::Receiver<codive_tunnel::DataMessage>),
135 Streaming(tokio::sync::mpsc::Receiver<codive_tunnel::DataMessage>),
136 }
137
138 let response_rx = if is_sse {
139 ResponseReceiver::Streaming(tunnel.register_streaming_request(request_id.clone()))
140 } else {
141 ResponseReceiver::Regular(tunnel.register_request(request_id.clone()))
142 };
143
144 debug!(
145 request_id = %request_id,
146 tunnel_id = %tunnel_id,
147 is_sse = is_sse,
148 "Sending request to tunnel"
149 );
150
151 if tunnel.ws_sender.send(WsMessage::Binary(wire_msg)).await.is_err() {
152 tunnel.pending_requests.remove(&request_id);
153 return (StatusCode::BAD_GATEWAY, "Tunnel connection lost").into_response();
154 }
155
156 debug!(
157 request_id = %request_id,
158 "Request sent, waiting for response"
159 );
160
161 match response_rx {
163 ResponseReceiver::Streaming(rx) => {
164 handle_sse_response(tunnel, request_id, rx).await
165 }
166 ResponseReceiver::Regular(rx) => {
167 handle_regular_response(tunnel, request_id, rx).await
168 }
169 }
170}
171
172async fn handle_regular_response(
174 tunnel: Arc<crate::tunnel::TunnelConnection>,
175 request_id: String,
176 response_rx: tokio::sync::oneshot::Receiver<codive_tunnel::DataMessage>,
177) -> Response<Body> {
178 let timeout = Duration::from_secs(30);
180 match tokio::time::timeout(timeout, response_rx).await {
181 Ok(Ok(response_msg)) => {
182 build_http_response(response_msg)
184 }
185 Ok(Err(_)) => {
186 (StatusCode::BAD_GATEWAY, "Tunnel disconnected").into_response()
188 }
189 Err(_) => {
190 tunnel.pending_requests.remove(&request_id);
192 (StatusCode::GATEWAY_TIMEOUT, "Request timed out").into_response()
193 }
194 }
195}
196
197async fn handle_sse_response(
199 tunnel: Arc<crate::tunnel::TunnelConnection>,
200 request_id: String,
201 mut response_rx: tokio::sync::mpsc::Receiver<codive_tunnel::DataMessage>,
202) -> Response<Body> {
203 let timeout = Duration::from_secs(30);
205 let initial_response = match tokio::time::timeout(timeout, response_rx.recv()).await {
206 Ok(Some(msg)) => msg,
207 Ok(None) => {
208 tunnel.complete_streaming_request(&request_id);
209 return (StatusCode::BAD_GATEWAY, "Tunnel disconnected").into_response();
210 }
211 Err(_) => {
212 tunnel.complete_streaming_request(&request_id);
213 return (StatusCode::GATEWAY_TIMEOUT, "Request timed out").into_response();
214 }
215 };
216
217 let (status, initial_headers, initial_body) = match initial_response {
219 codive_tunnel::DataMessage::HttpResponse {
220 status,
221 headers,
222 body,
223 streaming,
224 ..
225 } => {
226 if !streaming {
227 tunnel.complete_streaming_request(&request_id);
229 let response_msg = codive_tunnel::DataMessage::HttpResponse {
230 request_id: request_id.clone(),
231 status,
232 headers,
233 body,
234 streaming: false,
235 };
236 return build_http_response(response_msg);
237 }
238 (status, headers, body)
239 }
240 codive_tunnel::DataMessage::RequestError { message, .. } => {
241 tunnel.complete_streaming_request(&request_id);
242 return Response::builder()
243 .status(StatusCode::BAD_GATEWAY)
244 .body(Body::from(message))
245 .unwrap();
246 }
247 _ => {
248 tunnel.complete_streaming_request(&request_id);
249 return (StatusCode::INTERNAL_SERVER_ERROR, "Unexpected response type").into_response();
250 }
251 };
252
253 debug!(
254 request_id = %request_id,
255 status = %status,
256 "Starting SSE stream"
257 );
258
259 let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Result<Bytes, std::io::Error>>(100);
261
262 if let Some(b64_body) = initial_body {
264 if let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(&b64_body) {
265 let _ = body_tx.send(Ok(Bytes::from(bytes))).await;
266 }
267 }
268
269 let req_id = request_id.clone();
271 tokio::spawn(async move {
272 while let Some(msg) = response_rx.recv().await {
273 match msg {
274 codive_tunnel::DataMessage::HttpResponseChunk {
275 chunk,
276 is_final,
277 ..
278 } => {
279 if let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(&chunk) {
281 if body_tx.send(Ok(Bytes::from(bytes))).await.is_err() {
282 debug!(request_id = %req_id, "Client disconnected");
283 break;
284 }
285 }
286
287 if is_final {
288 debug!(request_id = %req_id, "SSE stream completed");
289 break;
290 }
291 }
292 codive_tunnel::DataMessage::RequestError { message, .. } => {
293 warn!(request_id = %req_id, error = %message, "SSE stream error");
294 break;
295 }
296 _ => {
297 }
299 }
300 }
301 });
303
304 let status = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
306 let mut response = Response::builder().status(status);
307
308 for (name, value) in initial_headers {
309 response = response.header(name, value);
310 }
311
312 let body_stream = ReceiverStream::new(body_rx);
314 let body = Body::from_stream(body_stream);
315
316 response.body(body).unwrap_or_else(|_| {
317 Response::builder()
318 .status(StatusCode::INTERNAL_SERVER_ERROR)
319 .body(Body::from("Internal error"))
320 .unwrap()
321 })
322}
323
/// Derives the tunnel id from a request `Host` value.
///
/// Both `host` and `base_domain` are reduced to their bare names. Anything from the
/// first `:` on (a port) is dropped, as is a trailing root dot (`example.com.`).
///
/// The id is the non-empty part of `host` that precedes `.<base_domain>`. The dot is
/// required, so the base domain only matches on a label boundary: `evilrelay.example.com`
/// does not match `relay.example.com`. For local development, `<id>.localhost` is also
/// accepted. An empty `base_domain` matches nothing.
///
/// Returns `None` when no non-empty id can be found.
fn extract_tunnel_id(host: &str, base_domain: &str) -> Option<String> {
    // Strip a `:port` suffix and a trailing root-label dot.
    fn bare_name(value: &str) -> &str {
        let name = value.split(':').next().unwrap_or(value);
        name.strip_suffix('.').unwrap_or(name)
    }

    let host = bare_name(host);
    let base = bare_name(base_domain);

    // Non-empty part of `host` preceding `.<suffix>`; requiring the dot enforces a label boundary.
    let id_before = |suffix: &str| -> Option<String> {
        if suffix.is_empty() {
            return None;
        }
        host.strip_suffix(suffix)?
            .strip_suffix('.')
            .filter(|id| !id.is_empty())
            .map(String::from)
    };

    id_before(base).or_else(|| id_before("localhost"))
}
346
347fn build_http_response(msg: codive_tunnel::DataMessage) -> Response<Body> {
349 match msg {
350 codive_tunnel::DataMessage::HttpResponse {
351 status,
352 headers,
353 body,
354 ..
355 } => {
356 let status = StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
357
358 let mut response = Response::builder().status(status);
359
360 for (name, value) in headers {
361 response = response.header(name, value);
362 }
363
364 let body = if let Some(b64_body) = body {
365 match base64::engine::general_purpose::STANDARD.decode(&b64_body) {
366 Ok(bytes) => Body::from(bytes),
367 Err(_) => Body::from(b64_body), }
369 } else {
370 Body::empty()
371 };
372
373 response.body(body).unwrap_or_else(|_| {
374 Response::builder()
375 .status(StatusCode::INTERNAL_SERVER_ERROR)
376 .body(Body::from("Internal error"))
377 .unwrap()
378 })
379 }
380 codive_tunnel::DataMessage::RequestError { message, .. } => Response::builder()
381 .status(StatusCode::BAD_GATEWAY)
382 .body(Body::from(message))
383 .unwrap(),
384 _ => Response::builder()
385 .status(StatusCode::INTERNAL_SERVER_ERROR)
386 .body(Body::from("Unexpected response type"))
387 .unwrap(),
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Table of `(host, base_domain, expected tunnel id)` checked against `extract_tunnel_id`.
    #[test]
    fn test_extract_tunnel_id() {
        let cases: [(&str, &str, Option<&str>); 4] = [
            ("abc123.relay.example.com", "relay.example.com", Some("abc123")),
            // Ports on the host and the configured base domain are ignored.
            ("abc123.relay.example.com:3001", "relay.example.com:3001", Some("abc123")),
            ("abc123.localhost", "localhost", Some("abc123")),
            // The bare base domain carries no tunnel id.
            ("relay.example.com", "relay.example.com", None),
        ];

        for &(host, base_domain, expected) in cases.iter() {
            assert_eq!(
                extract_tunnel_id(host, base_domain),
                expected.map(String::from),
                "host={host:?} base_domain={base_domain:?}"
            );
        }
    }
}