use anyhow::Result;
use bytes::BytesMut;
use http::{Method, Request, Response, StatusCode};
use http_body_util::Full;
use hyper::body::{Bytes, Incoming};
use hyper::service::service_fn;
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::server::TlsStream;
use tracing::{debug, error, info, warn};
use crate::auth::AuthError;
use crate::config::Config;
use crate::rate_limiter::RateLimitError;
/// Splits a CONNECT-style `host:port` authority into its host and port.
///
/// Accepts registered names and IPv4 literals (`example.com:443`, `10.0.0.1:80`) and
/// bracketed IPv6 literals (`[::1]:443`). The returned host keeps the brackets.
///
/// # Errors
///
/// Returns a human-readable message (callers send it back in a 400) when:
/// - there is no `:` separating host and port,
/// - the host is empty,
/// - the host carries userinfo (`user@host`), which an authority-form target must not have,
/// - the host contains `:` / `[` / `]` without being a well-formed `[...]` IPv6 literal,
/// - the port is not a plain decimal number in `1..=65535`.
fn parse_authority(authority: &str) -> Result<(String, u16), String> {
    // Split on the LAST colon so the colons inside a bracketed IPv6 literal stay in the host.
    let (host, port_str) = match authority.rsplit_once(':') {
        Some(parts) => parts,
        None => return Err("Authority must be in host:port format".to_string()),
    };
    if host.is_empty() {
        return Err("Host cannot be empty".to_string());
    }
    if host.contains('@') {
        return Err("Authority must not contain userinfo".to_string());
    }
    if let Some(rest) = host.strip_prefix('[') {
        // IPv6 literal: exactly `[...]` with a non-empty body and no stray brackets.
        match rest.strip_suffix(']') {
            Some(inner) if !inner.is_empty() && !inner.contains(|c: char| c == '[' || c == ']') => {}
            _ => return Err("Invalid IPv6 literal: expected [address]:port".to_string()),
        }
    } else if host.contains(|c: char| matches!(c, ':' | '[' | ']')) {
        return Err("IPv6 addresses must be enclosed in brackets".to_string());
    }
    // `u16::from_str` also accepts a leading `+`; a port must be plain ASCII digits.
    let invalid_port = || {
        format!(
            "Invalid port '{}': must be a number between 1 and 65535",
            port_str
        )
    };
    if port_str.is_empty() || !port_str.bytes().all(|b| b.is_ascii_digit()) {
        return Err(invalid_port());
    }
    let port: u16 = port_str.parse().map_err(|_| invalid_port())?;
    if port == 0 {
        return Err("Invalid port: must be between 1 and 65535".to_string());
    }
    Ok((host.to_string(), port))
}
/// Drives one HTTP/2 connection over an accepted TLS stream.
///
/// After the h2 handshake, every accepted stream is handled on its own task:
/// - CONNECT streams go to `handle_h2_connect`.
/// - Other methods go to `handle_h2_http_request` when `config.http_proxy_enabled` is set.
/// - Otherwise the stream gets an empty `204 No Content`.
///
/// The accept loop stops when the peer closes the connection or an accept error occurs.
///
/// # Errors
/// Only a failed handshake is returned; per-stream failures are logged.
pub async fn serve_h2(tls_stream: TlsStream<TcpStream>, config: Arc<Config>) -> Result<()> {
    info!("HTTP/2 connection handler started");

    let mut connection = h2::server::Builder::new()
        .initial_window_size(65535)
        .initial_connection_window_size(1024 * 1024)
        .max_concurrent_streams(100)
        .max_frame_size(16384)
        .handshake(tls_stream)
        .await
        .map_err(|e| anyhow::anyhow!("HTTP/2 handshake failed: {}", e))?;

    info!("HTTP/2 handshake complete, accepting streams");

    while let Some(accepted) = connection.accept().await {
        let (request, mut respond) = match accepted {
            Ok(pair) => pair,
            Err(e) => {
                error!("[H2] Error accepting stream: {}", e);
                break;
            }
        };

        let stream_config = Arc::clone(&config);
        tokio::spawn(async move {
            let method = request.method().clone();
            let uri = request.uri().clone();

            let outcome = if method == Method::CONNECT {
                handle_h2_connect(request, respond, stream_config).await
            } else if stream_config.http_proxy_enabled {
                handle_h2_http_request(request, respond, stream_config).await
            } else {
                info!("[H2] Non-CONNECT {} request for {} - HTTP forwarding disabled, returning 204", method, uri);
                let stub = Response::builder()
                    .status(StatusCode::NO_CONTENT)
                    .body(())
                    .unwrap();
                respond
                    .send_response(stub, true)
                    .map(|_| ())
                    .map_err(|e| {
                        error!("[H2] Failed to send stub response: {}", e);
                        anyhow::anyhow!("Failed to send response: {}", e)
                    })
            };

            if let Err(e) = outcome {
                error!("[H2] Stream handler error: {}", e);
            }
        });
    }

    info!("HTTP/2 connection closed");
    Ok(())
}
async fn handle_h2_http_request(
request: Request<h2::RecvStream>,
mut respond: h2::server::SendResponse<Bytes>,
config: Arc<Config>,
) -> Result<()> {
let start_time = std::time::Instant::now();
let method = request.method().clone();
let uri = request.uri().clone();
if method == Method::TRACE {
warn!("[H2 HTTP] Blocked TRACE method (XST prevention)");
send_h2_error(
&mut respond,
StatusCode::METHOD_NOT_ALLOWED,
"TRACE method not allowed",
)
.await?;
return Ok(());
}
let target_url = uri.to_string();
if uri.scheme().is_none() || uri.authority().is_none() {
warn!(
"[H2 HTTP] Invalid proxy request URI (missing scheme/authority): {}",
target_url
);
send_h2_error(
&mut respond,
StatusCode::BAD_REQUEST,
"Proxy requests must use absolute-form URI",
)
.await?;
return Ok(());
}
let scheme = uri.scheme_str().unwrap_or("http");
let authority = uri.authority().unwrap().as_str();
info!(
"[H2 HTTP] {} request for {} (scheme={}, authority={})",
method, target_url, scheme, authority
);
let claims = match config.jwt_validator.validate_request(&request) {
Ok(claims) => claims,
Err(e) => {
return handle_h2_auth_error(e, &target_url, start_time, &mut respond).await;
}
};
debug!(
"[H2 HTTP] Authenticated {} - user_id={}, token_id={}",
target_url, claims.user_id, claims.token_id
);
match config.rate_limiter.check_limit(&claims.token_id).await {
Ok(()) => {}
Err(e) => {
return handle_h2_rate_limit_error(
e,
&target_url,
&claims.token_id,
start_time,
&mut respond,
)
.await;
}
};
let request_headers = request.headers().clone();
let mut recv_stream = request.into_body();
let mut body_bytes = BytesMut::new();
while let Some(chunk) = recv_stream.data().await {
match chunk {
Ok(data) => {
recv_stream.flow_control().release_capacity(data.len()).ok();
body_bytes.extend_from_slice(&data);
}
Err(e) => {
error!("[H2 HTTP] Error reading request body: {}", e);
send_h2_error(
&mut respond,
StatusCode::BAD_REQUEST,
"Failed to read request body",
)
.await?;
return Ok(());
}
}
}
let client = match reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.redirect(reqwest::redirect::Policy::none())
.build()
{
Ok(c) => c,
Err(e) => {
error!("[H2 HTTP] Failed to create HTTP client: {}", e);
send_h2_error(
&mut respond,
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error",
)
.await?;
return Ok(());
}
};
let method_str = method.as_str();
let upstream_method = reqwest::Method::from_bytes(method_str.as_bytes())
.map_err(|e| anyhow::anyhow!("Invalid HTTP method: {}", e))?;
let mut upstream_req = client.request(upstream_method, &target_url);
if !body_bytes.is_empty() {
upstream_req = upstream_req.body(body_bytes.to_vec());
}
let hop_by_hop_headers = [
"connection",
"keep-alive",
"proxy-connection",
"proxy-authorization", "transfer-encoding",
"upgrade",
"te", ];
debug!("[H2 HTTP] Request headers received: {:?}", request_headers);
for (name, value) in request_headers.iter() {
let name_str = name.as_str();
if hop_by_hop_headers.contains(&name_str) {
continue;
}
if name_str.starts_with(':') {
continue;
}
if let Ok(header_value) = value.to_str() {
debug!(
"[H2 HTTP] Forwarding header: {}: {}",
name_str, header_value
);
upstream_req = upstream_req.header(name_str, header_value);
}
}
upstream_req = upstream_req.header("host", authority);
debug!("[H2 HTTP] Set Host header to: {}", authority);
let upstream_response = match upstream_req.send().await {
Ok(resp) => resp,
Err(e) => {
error!(
"[H2 HTTP] Failed to connect to upstream {}: {}",
target_url, e
);
config
.request_logger
.log_request(
claims.token_id.clone(),
claims.user_id as i32,
method.as_str().to_string(),
target_url.clone(),
Some(502),
0,
Some(start_time.elapsed().as_millis() as i64),
false,
false,
Some(format!("Connection failed: {}", e)),
)
.await;
send_h2_error(
&mut respond,
StatusCode::BAD_GATEWAY,
"Failed to connect to upstream",
)
.await?;
return Ok(());
}
};
let status = upstream_response.status();
let upstream_headers = upstream_response.headers().clone();
let response_bytes = match upstream_response.bytes().await {
Ok(bytes) => bytes,
Err(e) => {
error!("[H2 HTTP] Failed to read response body: {}", e);
send_h2_error(
&mut respond,
StatusCode::BAD_GATEWAY,
"Failed to read response body",
)
.await?;
return Ok(());
}
};
let total_bytes = response_bytes.len();
let duration = start_time.elapsed();
let status_code =
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let mut response_builder = Response::builder().status(status_code);
for (name, value) in &upstream_headers {
let name_str = name.as_str().to_lowercase();
if matches!(
name_str.as_str(),
"connection" | "keep-alive" | "te" | "trailer" | "transfer-encoding" | "upgrade"
) {
continue;
}
if let Ok(value_str) = value.to_str() {
response_builder = response_builder.header(name.as_str(), value_str);
}
}
let response = response_builder.body(()).unwrap();
let mut send_stream = match respond.send_response(response, false) {
Ok(stream) => stream,
Err(e) => {
error!("[H2 HTTP] Failed to send response headers: {}", e);
return Err(anyhow::anyhow!("Failed to send response: {}", e));
}
};
if !response_bytes.is_empty() {
if let Err(e) = send_stream.send_data(response_bytes.clone(), true) {
error!("[H2 HTTP] Failed to send response body: {}", e);
return Err(anyhow::anyhow!("Failed to send body: {}", e));
}
} else {
if let Err(e) = send_stream.send_data(Bytes::new(), true) {
error!("[H2 HTTP] Failed to close stream: {}", e);
}
}
config
.request_logger
.log_request(
claims.token_id.clone(),
claims.user_id as i32,
method.as_str().to_string(),
target_url.clone(),
Some(status.as_u16() as i32),
total_bytes as i64,
Some(duration.as_millis() as i64),
true,
false,
None,
)
.await;
info!(
"[H2 HTTP] Completed {} {} - user_id={}, token_id={}, status={}, duration={:?}, bytes={}",
method,
target_url,
claims.user_id,
claims.token_id,
status.as_u16(),
duration,
total_bytes
);
Ok(())
}
/// Handles an HTTP/2 CONNECT request.
///
/// Steps, in order:
/// 1. Validate the `host:port` authority.
/// 2. Authenticate the JWT.
/// 3. Apply the per-token rate limit.
/// 4. Vet the destination host with `config.destination_filter` (the policy HTTP
///    forwarding applies).
/// 5. Open a TCP connection to the target and answer `200`.
/// 6. Spawn a task that relays bytes through `tunnel_h2_streams`, then logs usage.
///
/// Rejections are answered on the stream and yield `Ok(())`. `Err` means the `200`
/// (or an error response) could not be written. Returns once the tunnel task is spawned.
async fn handle_h2_connect(
    request: Request<h2::RecvStream>,
    mut respond: h2::server::SendResponse<Bytes>,
    config: Arc<Config>,
) -> Result<()> {
    let start_time = std::time::Instant::now();
    let target_host = match request.uri().authority() {
        Some(auth) => auth.to_string(),
        None => {
            warn!("[H2] Missing authority in CONNECT request");
            send_h2_error(
                &mut respond,
                StatusCode::BAD_REQUEST,
                "Bad Request: CONNECT requires a valid host:port authority",
            )
            .await?;
            return Ok(());
        }
    };
    let (host, _port) = match parse_authority(&target_host) {
        Ok(parts) => parts,
        Err(err_msg) => {
            warn!("[H2] Invalid authority {}: {}", target_host, err_msg);
            send_h2_error(
                &mut respond,
                StatusCode::BAD_REQUEST,
                &format!("Bad Request: {}", err_msg),
            )
            .await?;
            return Ok(());
        }
    };

    let claims = match config.jwt_validator.validate_request(&request) {
        Ok(claims) => claims,
        Err(e) => {
            return handle_h2_auth_error(e, &target_host, start_time, &mut respond).await;
        }
    };
    info!(
        "[H2 CONNECT] Authenticated {} - user_id={}, token_id={}, regions={:?}",
        target_host, claims.user_id, claims.token_id, claims.allowed_regions
    );

    match config.rate_limiter.check_limit(&claims.token_id).await {
        Ok(()) => {}
        Err(e) => {
            return handle_h2_rate_limit_error(
                e,
                &target_host,
                &claims.token_id,
                start_time,
                &mut respond,
            )
            .await;
        }
    };

    // Apply the destination policy used for HTTP forwarding, so a tunnel cannot reach
    // hosts that forwarding blocks. NOTE(review): the connect below resolves
    // `target_host` again instead of reusing the vetted addresses, so DNS rebinding
    // between the two lookups is not covered.
    if let Err(e) = config.destination_filter.check_and_resolve(&host).await {
        warn!(
            "[H2 CONNECT] Destination blocked for {} - user_id={}, token_id={}: {}",
            target_host, claims.user_id, claims.token_id, e
        );
        send_h2_error(
            &mut respond,
            StatusCode::FORBIDDEN,
            &format!("Destination not allowed: {}", e),
        )
        .await?;
        return Ok(());
    }

    let upstream = match tokio::net::TcpStream::connect(&target_host).await {
        Ok(stream) => stream,
        Err(e) => {
            let duration = start_time.elapsed();
            error!(
                "[H2 CONNECT] Failed to connect to {} - user_id={}, token_id={}, error={}, duration={:?}",
                target_host, claims.user_id, claims.token_id, e, duration
            );
            send_h2_error(
                &mut respond,
                StatusCode::BAD_GATEWAY,
                "Failed to connect to upstream server",
            )
            .await?;
            return Ok(());
        }
    };
    info!(
        "[H2 CONNECT] Connected to {} - user_id={}, token_id={}",
        target_host, claims.user_id, claims.token_id
    );

    // 200 without END_STREAM: the stream stays open in both directions as the tunnel.
    let response = Response::builder().status(StatusCode::OK).body(()).unwrap();
    let send_stream = match respond.send_response(response, false) {
        Ok(stream) => stream,
        Err(e) => {
            error!("[H2 CONNECT] Failed to send response: {}", e);
            return Err(anyhow::anyhow!("Failed to send 200 response: {}", e));
        }
    };
    let recv_stream = request.into_body();
    info!(
        "[H2 CONNECT] Starting tunnel for {} - user_id={}, token_id={}",
        target_host, claims.user_id, claims.token_id
    );

    let config_for_tunnel = Arc::clone(&config);
    tokio::spawn(async move {
        match tunnel_h2_streams(
            recv_stream,
            send_stream,
            upstream,
            target_host.clone(),
            claims.user_id,
            claims.token_id.clone(),
            start_time,
        )
        .await
        {
            Ok((bytes_sent, bytes_received)) => {
                let duration = start_time.elapsed();
                let total_bytes = bytes_sent + bytes_received;
                info!(
                    "[H2 CONNECT] Completed {} - user_id={}, token_id={}, duration={:?}, \
                     client→upstream={} bytes, upstream→client={} bytes, total={} bytes",
                    target_host,
                    claims.user_id,
                    claims.token_id,
                    duration,
                    bytes_sent,
                    bytes_received,
                    total_bytes
                );
                config_for_tunnel
                    .request_logger
                    .log_request(
                        claims.token_id.clone(),
                        claims.user_id,
                        "CONNECT".to_string(),
                        target_host.clone(),
                        Some(200),
                        total_bytes as i64,
                        Some(duration.as_millis() as i64),
                        true,
                        false,
                        None,
                    )
                    .await;
            }
            Err(e) => {
                error!(
                    "[H2 CONNECT] Tunnel error for {} - user_id={}, token_id={}, error={}",
                    target_host, claims.user_id, claims.token_id, e
                );
            }
        }
    });
    Ok(())
}
async fn send_h2_error(
respond: &mut h2::server::SendResponse<Bytes>,
status: StatusCode,
message: &str,
) -> Result<()> {
let response = Response::builder()
.status(status)
.header("content-type", "text/plain")
.body(())
.unwrap();
let mut send_stream = respond
.send_response(response, false)
.map_err(|e| anyhow::anyhow!("Failed to send error response: {}", e))?;
let body = Bytes::from(message.to_string());
send_stream
.send_data(body, true)
.map_err(|e| anyhow::anyhow!("Failed to send error body: {}", e))?;
Ok(())
}
/// Logs an authentication failure and answers the HTTP/2 stream accordingly.
///
/// Status mapping:
/// - Missing header or expired token → `407`, with a `proxy-authenticate` challenge.
/// - Malformed header → `400`.
/// - Failed validation or disallowed region → `403`.
///
/// The body is a `text/plain` explanation.
///
/// # Errors
/// Fails if the response head or body cannot be written to the stream.
async fn handle_h2_auth_error(
    error: AuthError,
    target_host: &str,
    start_time: std::time::Instant,
    respond: &mut h2::server::SendResponse<Bytes>,
) -> Result<()> {
    let elapsed = start_time.elapsed();

    let (status, message) = match &error {
        AuthError::MissingHeader => {
            warn!(
                "[H2 CONNECT] Missing Proxy-Authorization for {} (duration={:?})",
                target_host, elapsed
            );
            (
                StatusCode::PROXY_AUTHENTICATION_REQUIRED,
                "Proxy authentication required. Please provide a valid Bearer token in the Proxy-Authorization header.",
            )
        }
        AuthError::InvalidFormat => {
            warn!(
                "[H2 CONNECT] Invalid auth format for {} (duration={:?})",
                target_host, elapsed
            );
            (
                StatusCode::BAD_REQUEST,
                "Invalid Proxy-Authorization format. Expected: Bearer <token>",
            )
        }
        AuthError::TokenExpired => {
            warn!(
                "[H2 CONNECT] Expired token for {} (duration={:?})",
                target_host, elapsed
            );
            (
                StatusCode::PROXY_AUTHENTICATION_REQUIRED,
                "Token expired. Please obtain a new authentication token.",
            )
        }
        AuthError::ValidationFailed(detail) => {
            warn!(
                "[H2 CONNECT] Token validation failed for {}: {} (duration={:?})",
                target_host, detail, elapsed
            );
            (
                StatusCode::FORBIDDEN,
                "Token validation failed. The provided token is invalid.",
            )
        }
        AuthError::RegionNotAllowed(detail) => {
            warn!(
                "[H2 CONNECT] Region not allowed for {}: {} (duration={:?})",
                target_host, detail, elapsed
            );
            (
                StatusCode::FORBIDDEN,
                "Access denied: Region not in allowed list for this token.",
            )
        }
    };

    let builder = Response::builder()
        .status(status)
        .header("content-type", "text/plain");
    // Only 407 responses carry an authentication challenge.
    let builder = if status == StatusCode::PROXY_AUTHENTICATION_REQUIRED {
        builder.header(
            "proxy-authenticate",
            "Basic realm=\"ProbeOps Forward Proxy\"",
        )
    } else {
        builder
    };
    let head = builder.body(()).unwrap();

    let mut body_stream = respond
        .send_response(head, false)
        .map_err(|e| anyhow::anyhow!("Failed to send auth error response: {}", e))?;
    body_stream
        .send_data(Bytes::from(message.to_string()), true)
        .map_err(|e| anyhow::anyhow!("Failed to send auth error body: {}", e))?;
    Ok(())
}
async fn handle_h2_rate_limit_error(
error: RateLimitError,
target_host: &str,
token_id: &str,
start_time: std::time::Instant,
respond: &mut h2::server::SendResponse<Bytes>,
) -> Result<()> {
let duration = start_time.elapsed();
let (status, message): (StatusCode, String) = match error {
RateLimitError::LimitExceeded(_) => {
warn!(
"[H2 CONNECT] Rate limit exceeded for {} - token_id={} (duration={:?})",
target_host, token_id, duration
);
(
StatusCode::TOO_MANY_REQUESTS,
"Rate limit exceeded. Please retry after a short delay.".to_string(),
)
}
RateLimitError::TooManyTokens(max) => {
error!(
"[H2 CONNECT] Too many tokens for {} - max={} (duration={:?})",
target_host, max, duration
);
(
StatusCode::SERVICE_UNAVAILABLE,
format!(
"Service temporarily unavailable. Maximum {} concurrent tokens reached.",
max
),
)
}
};
let response = Response::builder()
.status(status)
.header("content-type", "text/plain")
.body(())
.unwrap();
let mut send_stream = respond
.send_response(response, false)
.map_err(|e| anyhow::anyhow!("Failed to send rate limit error response: {}", e))?;
let body = Bytes::from(message);
send_stream
.send_data(body, true)
.map_err(|e| anyhow::anyhow!("Failed to send rate limit error body: {}", e))?;
Ok(())
}
/// Writes the first `len` bytes of `data` to an HTTP/2 stream without ending it,
/// honouring the peer's flow-control window. Callers pass `len == data.len()`, and
/// `len` is clamped to `data.len()`.
///
/// Capacity is requested with `reserve_capacity` and awaited with `poll_capacity`, so
/// the task sleeps until the peer opens its window instead of polling on a timer. The
/// payload goes out in as many DATA frames as the granted capacity requires, so a peer
/// with a small window still makes progress.
///
/// # Errors
/// Fails if:
/// - the stream is closed or reset,
/// - `send_data` fails,
/// - no capacity is granted within `CAPACITY_WAIT` (the peer stopped reading).
async fn send_with_flow_control(
    send_stream: &mut h2::SendStream<Bytes>,
    data: Bytes,
    len: usize,
) -> Result<()> {
    use tokio::time::{timeout, Duration};

    // Upper bound on a single wait for window, close to the ~47 s the previous
    // exponential-backoff loop allowed before giving up.
    const CAPACITY_WAIT: Duration = Duration::from_secs(45);

    let mut pending = data.slice(..len.min(data.len()));
    while !pending.is_empty() {
        // Request room to buffer everything still pending (relative to what is buffered).
        send_stream.reserve_capacity(pending.len());

        let mut available = send_stream.capacity();
        if available == 0 {
            let granted = timeout(
                CAPACITY_WAIT,
                std::future::poll_fn(|cx| send_stream.poll_capacity(cx)),
            )
            .await;
            available = match granted {
                Ok(Some(Ok(n))) => n,
                Ok(Some(Err(e))) => {
                    return Err(anyhow::anyhow!("Stream closed or reset: {}", e));
                }
                Ok(None) => {
                    return Err(anyhow::anyhow!(
                        "Stream closed before flow-control capacity was granted"
                    ));
                }
                Err(_) => {
                    return Err(anyhow::anyhow!(
                        "Flow control timeout: needed {} bytes, none granted within {:?}",
                        pending.len(),
                        CAPACITY_WAIT
                    ));
                }
            };
            if available == 0 {
                continue;
            }
        }

        let chunk = pending.split_to(available.min(pending.len()));
        send_stream
            .send_data(chunk, false)
            .map_err(|e| anyhow::anyhow!("Failed to send data: {}", e))?;
    }
    Ok(())
}
async fn tunnel_h2_streams(
mut recv_stream: h2::RecvStream,
mut send_stream: h2::SendStream<Bytes>,
upstream: TcpStream,
target_host: String,
user_id: i32,
token_id: String,
start_time: std::time::Instant,
) -> Result<(u64, u64)> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let (mut upstream_read, mut upstream_write) = upstream.into_split();
let mut bytes_client_to_upstream = 0u64;
let mut bytes_upstream_to_client = 0u64;
let mut upstream_buf = vec![0u8; 16384];
loop {
tokio::select! {
result = recv_stream.data() => {
match result {
Some(Ok(data)) => {
let len = data.len();
upstream_write.write_all(&data).await?;
bytes_client_to_upstream += len as u64;
let _ = recv_stream.flow_control().release_capacity(len);
}
Some(Err(e)) => {
error!(
"[H2 TUNNEL] RecvStream error for {} - user_id={}, token_id={}, error={}",
target_host, user_id, token_id, e
);
break;
}
None => {
debug!(
"[H2 TUNNEL] Client closed stream for {} - user_id={}, token_id={}",
target_host, user_id, token_id
);
break;
}
}
}
result = upstream_read.read(&mut upstream_buf) => {
match result {
Ok(0) => {
debug!(
"[H2 TUNNEL] Upstream closed for {} - user_id={}, token_id={}",
target_host, user_id, token_id
);
break;
}
Ok(n) => {
let data = Bytes::copy_from_slice(&upstream_buf[..n]);
if let Err(e) = send_with_flow_control(&mut send_stream, data, n).await {
error!(
"[H2 TUNNEL] SendStream error for {} - user_id={}, token_id={}, error={}",
target_host, user_id, token_id, e
);
break;
}
bytes_upstream_to_client += n as u64;
}
Err(e) => {
error!(
"[H2 TUNNEL] Upstream read error for {} - user_id={}, token_id={}, error={}",
target_host, user_id, token_id, e
);
break;
}
}
}
}
}
let _ = send_stream.send_data(Bytes::new(), true);
let duration = start_time.elapsed();
info!(
"[H2 TUNNEL] Closed {} - user_id={}, token_id={}, duration={:?}, \
client→upstream={} bytes, upstream→client={} bytes",
target_host,
user_id,
token_id,
duration,
bytes_client_to_upstream,
bytes_upstream_to_client
);
Ok((bytes_client_to_upstream, bytes_upstream_to_client))
}
pub async fn serve_http1(tls_stream: TlsStream<TcpStream>, config: Arc<Config>) -> Result<()> {
info!("HTTP/1.1 connection handler started");
let io = TokioIo::new(tls_stream);
let conn = hyper::server::conn::http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| {
let config = Arc::clone(&config);
async move { handle_request(req, config).await }
}),
)
.with_upgrades();
if let Err(e) = conn.await {
error!("HTTP/1.1 connection error: {}", e);
}
Ok(())
}
/// Routes one HTTP/1.1 request.
///
/// - CONNECT requests go to `handle_connect`.
/// - Other methods go to `forward_http_request` when `config.http_proxy_enabled` is set.
/// - Otherwise the request gets an empty `204 No Content` with `Connection: close`.
async fn handle_request(
    req: Request<Incoming>,
    config: Arc<Config>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
    let method = req.method().clone();
    let uri = req.uri().clone();
    debug!("Received {} request for {}", method, uri);

    if method == Method::CONNECT {
        return handle_connect(req, config).await;
    }
    if config.http_proxy_enabled {
        return forward_http_request(req, config).await;
    }

    info!(
        "[HTTP] Non-CONNECT {} request for {} - HTTP forwarding disabled, returning 204",
        method, uri
    );
    let stub = Response::builder()
        .status(StatusCode::NO_CONTENT)
        .header("Connection", "close")
        .body(Full::new(Bytes::new()))
        .unwrap();
    Ok(stub)
}
/// Proxies a non-CONNECT HTTP/1.1 request (absolute-form URI) to its upstream origin.
///
/// Steps, in order:
/// 1. Reject `TRACE`.
/// 2. Require an absolute-form URI.
/// 3. Authenticate the JWT.
/// 4. Enforce the per-token client-IP limit (`config.ip_tracker`). The client IP comes
///    from a `SocketAddr` request extension, falling back to `127.0.0.1`.
/// 5. Apply the per-token rate limit.
/// 6. Run the mixed-content policy, which may answer directly or rewrite the URI.
/// 7. Read the body up to `config.max_request_body_size`.
/// 8. Vet the destination host with `config.destination_filter`.
/// 9. Forward via `crate::http_client::forward_request` to the vetted addresses and log
///    the outcome.
///
/// Every outcome is returned as `Ok(response)`; responses built here carry
/// `Connection: close`.
async fn forward_http_request(
    req: Request<Incoming>,
    config: Arc<Config>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
    let start_time = std::time::Instant::now();
    let method = req.method().clone();
    let uri = req.uri().clone();
    let target_url = uri.to_string();

    if method == Method::TRACE {
        warn!("[HTTP] Blocked TRACE method");
        return Ok(Response::builder()
            .status(StatusCode::METHOD_NOT_ALLOWED)
            .header("Allow", "GET, HEAD, POST, PUT, DELETE, PATCH, OPTIONS")
            .header("Connection", "close")
            .body(Full::new(Bytes::from("TRACE method not allowed")))
            .unwrap());
    }
    if uri.scheme().is_none() || uri.authority().is_none() {
        warn!(
            "[HTTP] Invalid URI (missing scheme/authority): {}",
            target_url
        );
        return Ok(Response::builder()
            .status(StatusCode::BAD_REQUEST)
            .header("Connection", "close")
            .body(Full::new(Bytes::from(
                "Proxy requests must use absolute-form URI",
            )))
            .unwrap());
    }
    debug!("[HTTP] {} {}", method, target_url);

    let claims = match config.jwt_validator.validate_request(&req) {
        Ok(claims) => claims,
        Err(AuthError::MissingHeader) => {
            return Ok(Response::builder()
                .status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
                .header("Proxy-Authenticate", "Bearer")
                .header("Connection", "close")
                .body(Full::new(Bytes::from("Proxy authentication required")))
                .unwrap());
        }
        Err(e) => {
            warn!("[HTTP] Auth failed: {}", e);
            return Ok(Response::builder()
                .status(StatusCode::FORBIDDEN)
                .header("Connection", "close")
                .body(Full::new(Bytes::from(format!(
                    "Authentication failed: {}",
                    e
                ))))
                .unwrap());
        }
    };

    let client_ip = req
        .extensions()
        .get::<std::net::SocketAddr>()
        .map(|addr| addr.ip())
        .unwrap_or_else(|| "127.0.0.1".parse().unwrap());
    if let Err(e) = config
        .ip_tracker
        .check_and_track(&claims.token_id, client_ip)
        .await
    {
        warn!("[HTTP] IP limit exceeded: {}", e);
        return Ok(Response::builder()
            .status(StatusCode::FORBIDDEN)
            .header("Connection", "close")
            .body(Full::new(Bytes::from(e.to_string())))
            .unwrap());
    }

    if let Err(e) = config.rate_limiter.check_limit(&claims.token_id).await {
        warn!("[HTTP] Rate limit exceeded: {}", e);
        return Ok(Response::builder()
            .status(StatusCode::TOO_MANY_REQUESTS)
            .header("Connection", "close")
            .body(Full::new(Bytes::from(format!(
                "Rate limit exceeded: {}",
                e
            ))))
            .unwrap());
    }

    let request_headers = req.headers().clone();
    let (mut parts, body) = req.into_parts();
    match crate::mixed_content::handle_mixed_content_request(
        &mut parts.uri,
        &request_headers,
        &config,
    )
    .await
    {
        Ok(Some(response)) => {
            return Ok(response);
        }
        Ok(None) => {}
        Err(e) => {
            // Policy errors are logged and the request still proceeds (fail-open).
            error!("[HTTP] Mixed content policy error: {}", e);
        }
    }

    let body_bytes =
        match crate::body_limiter::read_body_with_limit(body, config.max_request_body_size).await {
            Ok(bytes) => bytes,
            Err(e) => {
                warn!("[HTTP] Body limit error: {}", e);
                return Ok(Response::builder()
                    .status(e.status_code())
                    .header("Connection", "close")
                    .body(Full::new(Bytes::from(e.to_response_message())))
                    .unwrap());
            }
        };

    // `parts.uri` may have been rewritten by the mixed-content policy above, so its
    // authority is re-checked rather than unwrapped.
    let uri = &parts.uri;
    let scheme = uri.scheme_str().unwrap_or("http");
    let authority = match uri.authority() {
        Some(authority) => authority,
        None => {
            warn!(
                "[HTTP] Invalid URI after mixed-content policy (missing authority): {}",
                uri
            );
            return Ok(Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .header("Connection", "close")
                .body(Full::new(Bytes::from(
                    "Proxy requests must use absolute-form URI",
                )))
                .unwrap());
        }
    };
    let host = authority.host();
    let port = authority
        .port_u16()
        .unwrap_or(if scheme == "https" { 443 } else { 80 });
    let path = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");

    let vetted_ips = match config.destination_filter.check_and_resolve(host).await {
        Ok(ips) => ips,
        Err(e) => {
            warn!("[HTTP] Destination blocked: {}", e);
            return Ok(Response::builder()
                .status(StatusCode::FORBIDDEN)
                .header("Connection", "close")
                .body(Full::new(Bytes::from(format!(
                    "Destination not allowed: {}",
                    e
                ))))
                .unwrap());
        }
    };

    // The body is not used after this point, so it is moved rather than cloned.
    let body_option = if body_bytes.is_empty() {
        None
    } else {
        Some(body_bytes)
    };

    match crate::http_client::forward_request(
        vetted_ips,
        port,
        scheme,
        method.as_str(),
        path,
        host,
        &request_headers,
        body_option,
        &config,
    )
    .await
    {
        Ok((status, headers, body)) => {
            let mut response = Response::builder()
                .status(status)
                .header("Connection", "close");
            for (name, value) in &headers {
                response = response.header(name, value);
            }
            config
                .request_logger
                .log_request(
                    claims.token_id.clone(),
                    claims.user_id as i32,
                    method.as_str().to_string(),
                    target_url.clone(),
                    Some(status.as_u16() as i32),
                    body.len() as i64,
                    Some(start_time.elapsed().as_millis() as i64),
                    true,
                    false,
                    None,
                )
                .await;
            Ok(response.body(Full::new(body)).unwrap())
        }
        Err(e) => {
            error!("[HTTP] Forward failed: {}", e);
            let error_msg = match e {
                crate::http_client::HttpClientError::ResponseTooLarge { size, limit } => {
                    format!(
                        "Upstream response too large: {} bytes (limit: {})",
                        size, limit
                    )
                }
                _ => format!("Failed to connect to upstream: {}", e),
            };
            config
                .request_logger
                .log_request(
                    claims.token_id.clone(),
                    claims.user_id as i32,
                    method.as_str().to_string(),
                    target_url.clone(),
                    Some(502),
                    0,
                    Some(start_time.elapsed().as_millis() as i64),
                    false,
                    false,
                    Some(error_msg.clone()),
                )
                .await;
            Ok(Response::builder()
                .status(StatusCode::BAD_GATEWAY)
                .header("Connection", "close")
                .body(Full::new(Bytes::from(error_msg)))
                .unwrap())
        }
    }
}
/// Handles an HTTP/1.1 CONNECT request.
///
/// Steps, in order:
/// 1. Validate the `host:port` authority.
/// 2. Authenticate the JWT.
/// 3. Apply the per-token rate limit.
/// 4. Vet the destination host with `config.destination_filter` (the policy HTTP
///    forwarding applies).
/// 5. Open the upstream TCP connection.
/// 6. Return `200` to the client.
///
/// The byte relay runs in a spawned task. That task waits for hyper to hand over the
/// upgraded client connection, then calls `tunnel`. Rejections are returned as
/// ordinary responses; this function itself always returns `Ok`.
async fn handle_connect<B>(
    mut req: Request<B>,
    config: Arc<Config>,
) -> Result<Response<Full<Bytes>>, hyper::Error>
where
    B: hyper::body::Body + Send + 'static,
    B::Data: Send,
    B::Error: std::error::Error + Send + Sync,
{
    let start_time = std::time::Instant::now();
    let target_host = match req.uri().authority() {
        Some(auth) => auth.to_string(),
        None => {
            warn!("[CONNECT] Missing authority in CONNECT request");
            return Ok(Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .body(Full::new(Bytes::from(
                    "Bad Request: CONNECT requires a valid host:port authority",
                )))
                .unwrap());
        }
    };
    let (host, _port) = match parse_authority(&target_host) {
        Ok(parts) => parts,
        Err(err_msg) => {
            warn!("[CONNECT] Invalid authority {}: {}", target_host, err_msg);
            return Ok(Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .body(Full::new(Bytes::from(format!("Bad Request: {}", err_msg))))
                .unwrap());
        }
    };
    info!("[CONNECT] {} from client", target_host);

    let claims = match config.jwt_validator.validate_request(&req) {
        Ok(claims) => claims,
        Err(e) => {
            return handle_auth_error(e, &target_host, start_time);
        }
    };
    debug!(
        "[CONNECT] Authenticated user_id={}, token_id={}, allowed_regions={:?}",
        claims.user_id, claims.token_id, claims.allowed_regions
    );

    match config.rate_limiter.check_limit(&claims.token_id).await {
        Ok(()) => {}
        Err(e) => {
            return handle_rate_limit_error(e, &target_host, &claims.token_id, start_time);
        }
    };
    debug!(
        "[CONNECT] Rate limit check passed for token_id={}",
        claims.token_id
    );

    // Apply the destination policy used for HTTP forwarding, so a tunnel cannot reach
    // hosts that forwarding blocks. NOTE(review): the connect below resolves
    // `target_host` again instead of reusing the vetted addresses, so DNS rebinding
    // between the two lookups is not covered.
    if let Err(e) = config.destination_filter.check_and_resolve(&host).await {
        warn!("[CONNECT] Destination blocked for {}: {}", target_host, e);
        return Ok(Response::builder()
            .status(StatusCode::FORBIDDEN)
            .body(Full::new(Bytes::from(format!(
                "Destination not allowed: {}",
                e
            ))))
            .unwrap());
    }

    let upstream = match tokio::net::TcpStream::connect(&target_host).await {
        Ok(stream) => stream,
        Err(e) => {
            error!("[CONNECT] Failed to connect to {}: {}", target_host, e);
            return Ok(Response::builder()
                .status(StatusCode::BAD_GATEWAY)
                .body(Full::new(Bytes::from("Failed to connect to target")))
                .unwrap());
        }
    };
    info!("[CONNECT] Connected to upstream {}", target_host);

    let target_host_for_log = target_host.clone();
    let config_for_tunnel = Arc::clone(&config);
    // The upgrade only resolves after the 200 below has been written, so the relay must
    // run on its own task.
    tokio::spawn(async move {
        match hyper::upgrade::on(&mut req).await {
            Ok(upgraded) => {
                if let Err(e) = tunnel(
                    upgraded,
                    upstream,
                    target_host.clone(),
                    claims.user_id as i64,
                    claims.token_id,
                    start_time,
                    config_for_tunnel,
                )
                .await
                {
                    error!("[CONNECT] Tunnel error for {}: {}", target_host, e);
                }
            }
            Err(e) => {
                error!("[CONNECT] Upgrade error for {}: {}", target_host, e);
            }
        }
    });

    info!(
        "[CONNECT] Sending 200 Connection Established for {}",
        target_host_for_log
    );
    Ok(Response::builder()
        .status(StatusCode::OK)
        .body(Full::new(Bytes::new()))
        .unwrap())
}
/// Relays bytes between an upgraded HTTP/1.1 CONNECT client and the upstream TCP stream
/// until both directions reach EOF, then logs and records the usage.
///
/// `tokio::io::copy_bidirectional` shuts down the write side of each peer when the
/// other reaches EOF. An upstream close therefore reaches the client (and vice versa)
/// instead of leaving the tunnel hanging, while the other direction keeps flowing.
///
/// # Errors
/// Returns the first I/O error from either direction. In that case no usage record is
/// written.
async fn tunnel(
    upgraded: Upgraded,
    mut upstream: TcpStream,
    target_host: String,
    user_id: i64,
    token_id: String,
    start_time: std::time::Instant,
    config: Arc<Config>,
) -> Result<()> {
    let mut client = TokioIo::new(upgraded);
    let (c_to_u, u_to_c) = tokio::io::copy_bidirectional(&mut client, &mut upstream).await?;

    let duration = start_time.elapsed();
    let total_bytes = c_to_u + u_to_c;
    info!(
        "[CONNECT] Completed {} - user_id={}, token_id={}, duration={:?}, \
         client→upstream={} bytes, upstream→client={} bytes, total={} bytes",
        target_host, user_id, token_id, duration, c_to_u, u_to_c, total_bytes
    );
    config
        .request_logger
        .log_request(
            token_id.clone(),
            user_id as i32,
            "CONNECT".to_string(),
            target_host.clone(),
            Some(200),
            total_bytes as i64,
            Some(duration.as_millis() as i64),
            true,
            false,
            None,
        )
        .await;
    Ok(())
}
/// Builds the HTTP/1.1 CONNECT response for an authentication failure and logs it.
///
/// Status mapping:
/// - Missing header or expired token → `407`, with a `Proxy-Authenticate` challenge.
/// - Malformed header → `400`.
/// - Failed validation or disallowed region → `403`.
///
/// Always returns `Ok`.
fn handle_auth_error(
    error: AuthError,
    target_host: &str,
    start_time: std::time::Instant,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
    let elapsed = start_time.elapsed();

    let (status, message) = match &error {
        AuthError::MissingHeader => {
            warn!(
                "[CONNECT] Authentication failed for {}: missing header (duration={:?})",
                target_host, elapsed
            );
            (
                StatusCode::PROXY_AUTHENTICATION_REQUIRED,
                "Proxy authentication required. Please provide a valid JWT token in the Proxy-Authorization or Authorization header.",
            )
        }
        AuthError::InvalidFormat => {
            warn!(
                "[CONNECT] Authentication failed for {}: invalid format (duration={:?})",
                target_host, elapsed
            );
            (
                StatusCode::BAD_REQUEST,
                "Invalid authentication format. Expected: Proxy-Authorization: Bearer <token>",
            )
        }
        AuthError::TokenExpired => {
            warn!(
                "[CONNECT] Authentication failed for {}: token expired (duration={:?})",
                target_host, elapsed
            );
            (
                StatusCode::PROXY_AUTHENTICATION_REQUIRED,
                "JWT token has expired. Please refresh your token or login again to the ProbeOps platform.",
            )
        }
        AuthError::ValidationFailed(detail) => {
            warn!(
                "[CONNECT] Authentication failed for {}: {} (duration={:?})",
                target_host, detail, elapsed
            );
            (
                StatusCode::FORBIDDEN,
                "Authentication failed: Invalid JWT token",
            )
        }
        AuthError::RegionNotAllowed(detail) => {
            warn!(
                "[CONNECT] Authentication failed for {}: {} (duration={:?})",
                target_host, detail, elapsed
            );
            (
                StatusCode::FORBIDDEN,
                "Access denied: Region not in allowed list",
            )
        }
    };

    let builder = Response::builder().status(status);
    // Only 407 responses carry an authentication challenge.
    let builder = if status == StatusCode::PROXY_AUTHENTICATION_REQUIRED {
        builder.header(
            "Proxy-Authenticate",
            "Basic realm=\"ProbeOps Forward Proxy\"",
        )
    } else {
        builder
    };
    Ok(builder.body(Full::new(Bytes::from(message))).unwrap())
}
/// Builds the HTTP/1.1 CONNECT response for a rate-limiter rejection and logs it.
///
/// Status mapping:
/// - Per-token limit exceeded → `429`.
/// - Limiter's token table full → `503`.
///
/// Always returns `Ok`.
fn handle_rate_limit_error(
    error: RateLimitError,
    target_host: &str,
    token_id: &str,
    start_time: std::time::Instant,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
    let elapsed = start_time.elapsed();

    let (status, body) = match error {
        RateLimitError::TooManyTokens(max) => {
            error!(
                "[CONNECT] Too many tokens error for {} - max={} (duration={:?})",
                target_host, max, elapsed
            );
            let text = format!(
                "Service temporarily unavailable. Maximum {} concurrent tokens.",
                max
            );
            (StatusCode::SERVICE_UNAVAILABLE, Bytes::from(text))
        }
        RateLimitError::LimitExceeded(_) => {
            warn!(
                "[CONNECT] Rate limit exceeded for {} - token_id={} (duration={:?})",
                target_host, token_id, elapsed
            );
            (
                StatusCode::TOO_MANY_REQUESTS,
                Bytes::from("Rate limit exceeded. Please retry later."),
            )
        }
    };

    let response = Response::builder()
        .status(status)
        .body(Full::new(body))
        .unwrap();
    Ok(response)
}