use axum::{
body::Body,
extract::State,
http::{Request, Response, StatusCode},
response::IntoResponse,
};
use base64::Engine;
use bytes::Bytes;
use std::sync::Arc;
use std::time::Duration;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, info, warn};
use crate::state::RelayState;
use crate::tunnel::WsMessage;
/// Heuristically decides whether `request` expects a Server-Sent Events stream.
///
/// Returns `true` when the first `Accept` header is valid text containing
/// `text/event-stream`, or when the path ends in `/events` or `/stream`, or
/// contains an `/events/` segment.
fn is_sse_request(request: &Request<Body>) -> bool {
    let wants_event_stream = matches!(
        request.headers().get("accept").map(|value| value.to_str()),
        Some(Ok(accept)) if accept.contains("text/event-stream")
    );

    let path = request.uri().path();
    let looks_like_stream_path =
        path.ends_with("/events") || path.ends_with("/stream") || path.contains("/events/");

    wants_event_stream || looks_like_stream_path
}
pub async fn proxy_handler(
State(state): State<Arc<RelayState>>,
request: Request<Body>,
) -> impl IntoResponse {
let host = request
.headers()
.get(axum::http::header::HOST)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let tunnel_id = match extract_tunnel_id(host, &state.config.base_domain) {
Some(id) => id,
None => {
return (
StatusCode::NOT_FOUND,
format!("No tunnel found for host: {}", host),
)
.into_response();
}
};
let tunnel = match state.get_tunnel(&tunnel_id) {
Some(t) => t,
None => {
return (
StatusCode::BAD_GATEWAY,
format!("Tunnel '{}' is not connected", tunnel_id),
)
.into_response();
}
};
let is_sse = is_sse_request(&request);
info!(
tunnel_id = %tunnel_id,
method = %request.method(),
uri = %request.uri(),
is_sse = %is_sse,
"Proxying request"
);
let request_id = uuid::Uuid::new_v4().to_string();
let method = request.method().to_string();
let path = request.uri().path().to_string();
let query = request.uri().query().map(|s| s.to_string());
let mut headers = std::collections::HashMap::new();
for (name, value) in request.headers() {
if let Ok(v) = value.to_str() {
headers.insert(name.to_string(), v.to_string());
}
}
let body_bytes = match axum::body::to_bytes(request.into_body(), 10 * 1024 * 1024).await {
Ok(bytes) => bytes,
Err(e) => {
warn!(error = %e, "Failed to read request body");
return (StatusCode::BAD_REQUEST, "Failed to read request body").into_response();
}
};
let body = if body_bytes.is_empty() {
None
} else {
Some(base64::engine::general_purpose::STANDARD.encode(&body_bytes))
};
let data_msg = codive_tunnel::DataMessage::HttpRequest {
request_id: request_id.clone(),
client_id: "proxy".to_string(),
method,
path,
query,
headers,
body,
};
let msg_bytes = serde_json::to_vec(&data_msg).unwrap();
let wire_msg = codive_tunnel::WireMessage::encode_encrypted(
codive_tunnel::message_type::ENCRYPTED_REQUEST,
msg_bytes,
);
enum ResponseReceiver {
Regular(tokio::sync::oneshot::Receiver<codive_tunnel::DataMessage>),
Streaming(tokio::sync::mpsc::Receiver<codive_tunnel::DataMessage>),
}
let response_rx = if is_sse {
ResponseReceiver::Streaming(tunnel.register_streaming_request(request_id.clone()))
} else {
ResponseReceiver::Regular(tunnel.register_request(request_id.clone()))
};
debug!(
request_id = %request_id,
tunnel_id = %tunnel_id,
is_sse = is_sse,
"Sending request to tunnel"
);
if tunnel.ws_sender.send(WsMessage::Binary(wire_msg)).await.is_err() {
tunnel.pending_requests.remove(&request_id);
return (StatusCode::BAD_GATEWAY, "Tunnel connection lost").into_response();
}
debug!(
request_id = %request_id,
"Request sent, waiting for response"
);
match response_rx {
ResponseReceiver::Streaming(rx) => {
handle_sse_response(tunnel, request_id, rx).await
}
ResponseReceiver::Regular(rx) => {
handle_regular_response(tunnel, request_id, rx).await
}
}
}
async fn handle_regular_response(
tunnel: Arc<crate::tunnel::TunnelConnection>,
request_id: String,
response_rx: tokio::sync::oneshot::Receiver<codive_tunnel::DataMessage>,
) -> Response<Body> {
let timeout = Duration::from_secs(30);
match tokio::time::timeout(timeout, response_rx).await {
Ok(Ok(response_msg)) => {
build_http_response(response_msg)
}
Ok(Err(_)) => {
(StatusCode::BAD_GATEWAY, "Tunnel disconnected").into_response()
}
Err(_) => {
tunnel.pending_requests.remove(&request_id);
(StatusCode::GATEWAY_TIMEOUT, "Request timed out").into_response()
}
}
}
async fn handle_sse_response(
tunnel: Arc<crate::tunnel::TunnelConnection>,
request_id: String,
mut response_rx: tokio::sync::mpsc::Receiver<codive_tunnel::DataMessage>,
) -> Response<Body> {
let timeout = Duration::from_secs(30);
let initial_response = match tokio::time::timeout(timeout, response_rx.recv()).await {
Ok(Some(msg)) => msg,
Ok(None) => {
tunnel.complete_streaming_request(&request_id);
return (StatusCode::BAD_GATEWAY, "Tunnel disconnected").into_response();
}
Err(_) => {
tunnel.complete_streaming_request(&request_id);
return (StatusCode::GATEWAY_TIMEOUT, "Request timed out").into_response();
}
};
let (status, initial_headers, initial_body) = match initial_response {
codive_tunnel::DataMessage::HttpResponse {
status,
headers,
body,
streaming,
..
} => {
if !streaming {
tunnel.complete_streaming_request(&request_id);
let response_msg = codive_tunnel::DataMessage::HttpResponse {
request_id: request_id.clone(),
status,
headers,
body,
streaming: false,
};
return build_http_response(response_msg);
}
(status, headers, body)
}
codive_tunnel::DataMessage::RequestError { message, .. } => {
tunnel.complete_streaming_request(&request_id);
return Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from(message))
.unwrap();
}
_ => {
tunnel.complete_streaming_request(&request_id);
return (StatusCode::INTERNAL_SERVER_ERROR, "Unexpected response type").into_response();
}
};
debug!(
request_id = %request_id,
status = %status,
"Starting SSE stream"
);
let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Result<Bytes, std::io::Error>>(100);
if let Some(b64_body) = initial_body {
if let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(&b64_body) {
let _ = body_tx.send(Ok(Bytes::from(bytes))).await;
}
}
let req_id = request_id.clone();
tokio::spawn(async move {
while let Some(msg) = response_rx.recv().await {
match msg {
codive_tunnel::DataMessage::HttpResponseChunk {
chunk,
is_final,
..
} => {
if let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(&chunk) {
if body_tx.send(Ok(Bytes::from(bytes))).await.is_err() {
debug!(request_id = %req_id, "Client disconnected");
break;
}
}
if is_final {
debug!(request_id = %req_id, "SSE stream completed");
break;
}
}
codive_tunnel::DataMessage::RequestError { message, .. } => {
warn!(request_id = %req_id, error = %message, "SSE stream error");
break;
}
_ => {
}
}
}
});
let status = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
let mut response = Response::builder().status(status);
for (name, value) in initial_headers {
response = response.header(name, value);
}
let body_stream = ReceiverStream::new(body_rx);
let body = Body::from_stream(body_stream);
response.body(body).unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Internal error"))
.unwrap()
})
}
/// Extracts the tunnel id from a request host.
///
/// The id is the part of `host` in front of `base_domain`, for example
/// `abc123.relay.example.com` gives `abc123` for base `relay.example.com`.
/// Hosts under `localhost` are accepted as a fallback regardless of
/// `base_domain`.
///
/// Matching rules:
/// - Ports are ignored on both arguments, and a trailing root dot
///   (`example.com.`) is tolerated.
/// - The domain comparison is ASCII case-insensitive, because DNS names are;
///   the returned id keeps its original case.
/// - The id must be separated from the domain by a `.`, so
///   `evilrelay.example.com` does not match base `relay.example.com`.
///
/// Returns `None` when no non-empty id can be found.
fn extract_tunnel_id(host: &str, base_domain: &str) -> Option<String> {
    let host = host.split(':').next().unwrap_or(host).trim_end_matches('.');
    let base = base_domain
        .split(':')
        .next()
        .unwrap_or(base_domain)
        .trim_end_matches('.');

    // An empty base would match every host, so only the localhost fallback
    // applies in that case.
    let from_base = if base.is_empty() {
        None
    } else {
        subdomain_label(host, base)
    };

    from_base
        .or_else(|| subdomain_label(host, "localhost"))
        .map(str::to_string)
}

/// Returns the non-empty part of `host` in front of `.domain`, comparing
/// `domain` ASCII case-insensitively. Returns `None` if `host` is not a strict
/// subdomain of `domain`.
fn subdomain_label<'a>(host: &'a str, domain: &str) -> Option<&'a str> {
    let boundary = host.len().checked_sub(domain.len())?;
    // `get` yields None if `boundary` is not a char boundary (non-ASCII host).
    if !host.get(boundary..)?.eq_ignore_ascii_case(domain) {
        return None;
    }
    host.get(..boundary)?
        .strip_suffix('.')
        .filter(|label| !label.is_empty())
}
fn build_http_response(msg: codive_tunnel::DataMessage) -> Response<Body> {
match msg {
codive_tunnel::DataMessage::HttpResponse {
status,
headers,
body,
..
} => {
let status = StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let mut response = Response::builder().status(status);
for (name, value) in headers {
response = response.header(name, value);
}
let body = if let Some(b64_body) = body {
match base64::engine::general_purpose::STANDARD.decode(&b64_body) {
Ok(bytes) => Body::from(bytes),
Err(_) => Body::from(b64_body), }
} else {
Body::empty()
};
response.body(body).unwrap_or_else(|_| {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Internal error"))
.unwrap()
})
}
codive_tunnel::DataMessage::RequestError { message, .. } => Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from(message))
.unwrap(),
_ => Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Unexpected response type"))
.unwrap(),
}
}
/// Unit tests for host-to-tunnel-id extraction.
#[cfg(test)]
mod tests {
    use super::*;

    /// Table of (host, base domain, expected tunnel id) cases.
    #[test]
    fn test_extract_tunnel_id() {
        let cases: [(&str, &str, Option<&str>); 4] = [
            ("abc123.relay.example.com", "relay.example.com", Some("abc123")),
            (
                "abc123.relay.example.com:3001",
                "relay.example.com:3001",
                Some("abc123"),
            ),
            ("abc123.localhost", "localhost", Some("abc123")),
            ("relay.example.com", "relay.example.com", None),
        ];

        for (host, base_domain, expected) in cases {
            assert_eq!(
                extract_tunnel_id(host, base_domain),
                expected.map(str::to_string),
                "host={host:?} base_domain={base_domain:?}"
            );
        }
    }
}