use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::header::HeaderName;
use hyper::{Method, Request, Response, Uri};
use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::rt::TokioExecutor;
use tokio::time::timeout;
use tracing::{Instrument, debug, info, warn};
use crate::{IpRateLimiter, LoadBalancer, ProxyError, Result, RuntimeConfig, headers, tls};
type StdError = Box<dyn std::error::Error + Send + Sync>;
pub type BoxBody = http_body_util::combinators::BoxBody<Bytes, StdError>;
pub type HttpClient = Client<HttpConnector, BoxBody>;
pub type HttpsClient = Client<hyper_rustls::HttpsConnector<HttpConnector>, BoxBody>;
static REQUEST_ID: AtomicU64 = AtomicU64::new(1);
pub fn build_client(config: &RuntimeConfig) -> HttpClient {
Client::builder(TokioExecutor::new())
.pool_idle_timeout(config.pool_idle_timeout)
.pool_max_idle_per_host(config.pool_max_idle_per_host)
.build(HttpConnector::new())
}
pub fn build_https_client(config: &RuntimeConfig) -> HttpsClient {
let connector = tls::build_https_connector();
Client::builder(TokioExecutor::new())
.pool_idle_timeout(config.pool_idle_timeout)
.pool_max_idle_per_host(config.pool_max_idle_per_host)
.build(connector)
}
pub async fn handle_request<B, C>(
req: Request<B>,
client: Client<C, BoxBody>,
config: Arc<RuntimeConfig>,
balancer: LoadBalancer,
client_addr: SocketAddr,
rate_limiter: Option<&IpRateLimiter>,
) -> Result<Response<BoxBody>>
where
B: hyper::body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<StdError>,
C: hyper_util::client::legacy::connect::Connect + Clone + Send + Sync + 'static,
{
let request_id = REQUEST_ID.fetch_add(1, Ordering::Relaxed);
let method = req.method().clone();
let uri = req.uri().clone();
let span = tracing::info_span!(
"request",
id = request_id,
method = %method,
uri = %uri,
client = %client_addr,
);
async move {
if let Some(limiter) = rate_limiter {
limiter.check(&client_addr.ip()).map_err(|retry_after_ms| {
warn!(
ip = %client_addr.ip(),
retry_after_ms,
"rate limit exceeded"
);
ProxyError::RateLimited { retry_after_ms }
})?;
}
if headers::is_smuggling_attempt(req.headers()) {
warn!("request smuggling attempt detected");
return Err(ProxyError::RequestSmuggling);
}
if headers::content_length_exceeds(req.headers(), config.max_body_size) {
let declared = req
.headers()
.get(hyper::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown");
warn!(
content_length = declared,
limit = config.max_body_size,
"request body exceeds size limit"
);
return Err(ProxyError::BodyTooLarge {
limit: config.max_body_size,
});
}
if method == Method::GET {
inspect_get_request(&req, &config)?;
}
let upstream = balancer.next()?;
let upstream_uri_target = upstream.uri();
info!(upstream = %upstream_uri_target, "received request");
let rewritten_uri = rewrite_uri(&uri, upstream_uri_target)?;
let (mut parts, body) = req.into_parts();
headers::strip_hop_by_hop(&mut parts.headers);
headers::inject_forwarding_headers(&mut parts.headers, client_addr);
headers::rewrite_host(
&mut parts.headers,
upstream_uri_target
.authority()
.ok_or_else(|| ProxyError::InvalidUpstream("upstream has no authority".into()))?,
);
config.blocked_headers.iter().for_each(|blocked| {
if let Ok(name) = HeaderName::from_bytes(blocked.as_bytes()) {
parts.headers.remove(&name);
}
});
parts.uri = rewritten_uri;
debug!(
headers = ?parts.headers,
upstream_uri = %parts.uri,
"forwarding request"
);
let start = std::time::Instant::now();
let boxed_body = body.map_err(|e| e.into()).boxed();
let proxy_req = Request::from_parts(parts, boxed_body);
let upstream_result = timeout(config.request_timeout, client.request(proxy_req)).await;
let mut upstream_resp = match upstream_result {
Ok(Ok(resp)) => {
upstream.record_success(config.healthy_threshold);
resp
}
Ok(Err(e)) => {
let transitioned = upstream.record_failure(config.failure_threshold);
warn!(
error = %e,
latency_ms = start.elapsed().as_millis() as u64,
upstream = %upstream_uri_target,
marked_unhealthy = transitioned,
"upstream request failed"
);
return Err(ProxyError::Upstream(e));
}
Err(_elapsed) => {
let transitioned = upstream.record_failure(config.failure_threshold);
warn!(
timeout = ?config.request_timeout,
latency_ms = start.elapsed().as_millis() as u64,
upstream = %upstream_uri_target,
marked_unhealthy = transitioned,
"upstream request timed out"
);
return Err(ProxyError::Timeout(config.request_timeout));
}
};
let latency_ms = start.elapsed().as_millis() as u64;
info!(
status = upstream_resp.status().as_u16(),
latency_ms,
upstream = %upstream_uri_target,
"upstream responded"
);
headers::strip_hop_by_hop(upstream_resp.headers_mut());
if !config.strip_response_headers.is_empty() {
headers::strip_response_headers(
upstream_resp.headers_mut(),
&config.strip_response_headers,
);
}
let mut resp = build_response(upstream_resp, &config).await?;
resp.headers_mut().insert(
HeaderName::from_static("x-request-id"),
hyper::header::HeaderValue::from(request_id),
);
Ok(resp)
}
.instrument(span)
.await
}
fn inspect_get_request<B>(req: &Request<B>, config: &RuntimeConfig) -> Result<()> {
let headers = req.headers();
config
.blocked_headers
.iter()
.find(|blocked| {
HeaderName::from_bytes(blocked.as_bytes())
.ok()
.is_some_and(|name| headers.contains_key(&name))
})
.map_or(Ok(()), |name| {
warn!(header = %name, "blocked header detected");
Err(ProxyError::BlockedHeader(name.clone()))
})?;
let query = req.uri().query().unwrap_or_default();
config
.blocked_params
.iter()
.find(|param| query.contains(&format!("{param}=")))
.map_or(Ok(()), |name| {
warn!(param = %name, "blocked parameter detected");
Err(ProxyError::BlockedParam(name.clone()))
})
}
fn rewrite_uri(original: &Uri, upstream: &Uri) -> Result<Uri> {
let authority = upstream
.authority()
.ok_or_else(|| ProxyError::InvalidUpstream("upstream has no authority".into()))?;
let scheme = upstream
.scheme()
.ok_or_else(|| ProxyError::InvalidUpstream("upstream has no scheme".into()))?;
let path_and_query = original
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/");
Uri::builder()
.scheme(scheme.clone())
.authority(authority.clone())
.path_and_query(path_and_query)
.build()
.map_err(|e| ProxyError::Internal(format!("failed to build upstream URI: {e}")))
}
async fn build_response(
upstream_resp: Response<Incoming>,
config: &RuntimeConfig,
) -> Result<Response<BoxBody>> {
if config.mask_rules.is_empty() {
let (parts, body) = upstream_resp.into_parts();
return Ok(Response::from_parts(
parts,
body.map_err(|e| -> StdError { Box::new(e) }).boxed(),
));
}
let should_mask = upstream_resp
.headers()
.get(hyper::header::CONTENT_TYPE)
.and_then(|ct| ct.to_str().ok())
.is_some_and(|ct| ct.contains("text/") || ct.contains("application/x-www-form-urlencoded"));
if !should_mask {
let (parts, body) = upstream_resp.into_parts();
return Ok(Response::from_parts(
parts,
body.map_err(|e| -> StdError { Box::new(e) }).boxed(),
));
}
let (parts, body) = upstream_resp.into_parts();
let body_bytes = body
.collect()
.await
.map_err(|e| ProxyError::Internal(format!("failed to read upstream body: {e}")))?
.to_bytes();
let body_str = String::from_utf8_lossy(&body_bytes);
let masked = config.mask_sensitive_data(&body_str);
let mut response = Response::new(
Full::new(Bytes::from(masked))
.map_err(|never| -> StdError { match never {} })
.boxed(),
);
*response.status_mut() = parts.status;
*response.headers_mut() = parts.headers;
*response.version_mut() = parts.version;
Ok(response)
}
#[cfg(test)]
mod tests {
use http_body_util::Empty;
use super::*;
use crate::Config;
use crate::config::UpstreamConfig;
fn parse_uri(uri: &str) -> Uri {
uri.parse::<Uri>().expect("failed to parse URI")
}
fn single_upstream(addr: &str) -> Vec<UpstreamConfig> {
vec![UpstreamConfig {
address: addr.into(),
weight: 1,
}]
}
#[test]
fn rewrite_uri_preserves_path_and_query() {
let original = parse_uri("http://client-facing.com/api/v1?key=val");
let upstream = parse_uri("http://localhost:3000");
let result = rewrite_uri(&original, &upstream).unwrap();
assert_eq!(result.scheme_str(), Some("http"));
assert_eq!(result.authority().unwrap().as_str(), "localhost:3000");
assert_eq!(result.path_and_query().unwrap().as_str(), "/api/v1?key=val");
}
#[test]
fn rewrite_uri_defaults_to_root_path() {
let original = parse_uri("http://client-facing.com");
let upstream = parse_uri("http://localhost:3000");
let result = rewrite_uri(&original, &upstream).unwrap();
assert_eq!(result.path_and_query().unwrap().as_str(), "/");
}
#[test]
fn inspect_get_detects_blocked_header() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
blocked_headers: vec!["x-bad-header".into()],
..Default::default()
}
.into_runtime()
.unwrap();
let req = Request::builder()
.method(Method::GET)
.uri("http://example.com/")
.header("x-bad-header", "anything")
.body(Empty::<Bytes>::new())
.unwrap();
let result = inspect_get_request(&req, &config);
assert!(result.is_err());
}
#[test]
fn inspect_get_detects_blocked_param() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
blocked_params: vec!["secret_key".into()],
..Default::default()
}
.into_runtime()
.unwrap();
let req = Request::builder()
.method(Method::GET)
.uri("http://example.com/?secret_key=abc123")
.body(Empty::<Bytes>::new())
.unwrap();
let result = inspect_get_request(&req, &config);
assert!(result.is_err());
}
#[test]
fn inspect_get_allows_clean_request() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
blocked_headers: vec!["x-bad-header".into()],
blocked_params: vec!["secret_key".into()],
..Default::default()
}
.into_runtime()
.unwrap();
let req = Request::builder()
.method(Method::GET)
.uri("http://example.com/path?safe=true")
.header("x-good-header", "ok")
.body(Empty::<Bytes>::new())
.unwrap();
assert!(inspect_get_request(&req, &config).is_ok());
}
}