use anyhow::Error;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::header;
use hyper::{Request, Response, StatusCode, Uri};
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use std::sync::Arc;
use tokio::time::{timeout, Duration};
use tracing::{error, info};
#[cfg(feature = "logging")]
use tracing::info_span;
#[cfg(feature = "logging")]
use tracing::Instrument;
use crate::config::{extract_hostname, Config, SiteConfig};
#[cfg(feature = "logging")]
use crate::proxy::access_log::AccessLogGuard;
use crate::proxy::access_log::{ensure_request_id, final_request_id};
use crate::proxy::ActionResult;
use crate::proxy::directives::{
handle_header, handle_method, handle_redirect, handle_respond, handle_reverse_proxy,
handle_strip_prefix, handle_uri_replace,
};
/// Type-erased response body used by every handler in this file: it yields
/// `Bytes` chunks and reports failures as a boxed, thread-safe
/// `std::error::Error`.
type ResponseBody =
http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
/// Returns `true` when `name` is a hop-by-hop header that a proxy must not
/// relay to the next hop.
///
/// Covers the connection-specific fields named in RFC 9110 §7.6.1
/// (`Connection`, `Keep-Alive`, `Proxy-Connection`, `TE`, `Transfer-Encoding`,
/// `Upgrade`) in addition to `Trailer`, `Proxy-Authenticate` and
/// `Proxy-Authorization`.
fn is_hop_header(name: &header::HeaderName) -> bool {
    // The `header` module has no constants for these two names. `HeaderName`
    // is normalized to lower case, so a plain string comparison is enough.
    const UNNAMED_HOP_HEADERS: [&str; 2] = ["keep-alive", "proxy-connection"];

    matches!(
        name,
        &header::CONNECTION
            | &header::UPGRADE
            | &header::TE
            | &header::TRAILER
            | &header::TRANSFER_ENCODING
            | &header::PROXY_AUTHENTICATE
            | &header::PROXY_AUTHORIZATION
    ) || UNNAMED_HOP_HEADERS.contains(&name.as_str())
}
pub fn process_directives(
directives: &[crate::config::Directive],
req: &mut Request<Incoming>,
current_path: &str,
) -> Result<ActionResult, String> {
let mut modified_path = current_path.to_string();
for directive in directives {
match directive {
crate::config::Directive::Header { name, value } => {
if let Err(e) = handle_header(name, value.as_deref(), req) {
info!(" Failed to apply header {}: {}", name, e);
}
}
crate::config::Directive::UriReplace { find, replace } => {
handle_uri_replace(find, replace, &mut modified_path);
}
crate::config::Directive::StripPrefix { prefix } => {
handle_strip_prefix(prefix, &mut modified_path);
}
crate::config::Directive::HandlePath {
pattern,
directives: nested_directives,
} => {
if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
info!(" Matched handle_path: {}", pattern);
return process_directives(nested_directives, req, &remaining_path);
}
}
crate::config::Directive::Method {
methods,
directives: nested_directives,
} => {
if handle_method(methods, req) {
info!(" Matched method directive");
return process_directives(nested_directives, req, &modified_path);
}
}
crate::config::Directive::Redirect { status, url } => {
return Ok(handle_redirect(status, url));
}
crate::config::Directive::Respond { status, body } => {
return Ok(handle_respond(status, body));
}
crate::config::Directive::ReverseProxy {
to,
connect_timeout,
read_timeout,
} => {
return Ok(handle_reverse_proxy(
to,
&modified_path,
*connect_timeout,
*read_timeout,
));
}
}
}
Err(format!(
"No action directive (respond or reverse_proxy) found in configuration for path: {}",
current_path
))
}
pub async fn proxy(
mut req: Request<Incoming>,
client: Client<HttpsConnector<HttpConnector>, Incoming>,
config: Arc<Config>,
remote_addr: std::net::SocketAddr,
is_tls: bool,
) -> Result<Response<ResponseBody>, Error> {
let initial_request_id = ensure_request_id(&mut req);
#[cfg(feature = "logging")]
let method = req.method().clone().to_string();
let path = req.uri().path().to_string();
let host = req
.headers()
.get(hyper::header::HOST)
.and_then(|h| h.to_str().ok())
.unwrap_or("localhost")
.to_string();
#[cfg(feature = "logging")]
let span = info_span!("request", req_id = %initial_request_id);
#[allow(unused_variables)]
let future = async move {
#[cfg(feature = "logging")]
let mut log_guard = AccessLogGuard::new(
initial_request_id.clone(),
remote_addr,
method,
path.clone(),
host.clone(),
);
let site_config = match find_site(&config, &host, is_tls) {
Some(config) => config,
None => {
error!("No configuration found for host: {}", host);
let (response, _body_len) = error_response_with_id(
StatusCode::NOT_FOUND,
&format!("No configuration found for host: {}", host),
&initial_request_id,
);
#[cfg(feature = "logging")]
{
log_guard.set_bytes_sent(_body_len);
log_guard.finish(404);
}
return Ok(response);
}
};
let action_result = match process_directives(&site_config.directives, &mut req, &path) {
Ok(result) => result,
Err(e) => {
error!("Directive processing error: {}", e);
let final_id = final_request_id(&req, &initial_request_id);
#[cfg(feature = "logging")]
{
log_guard.set_request_id(final_id.clone());
tracing::Span::current().record("req_id", final_id.as_str());
}
let (response, _body_len) =
error_response_with_id(StatusCode::INTERNAL_SERVER_ERROR, &e, &final_id);
#[cfg(feature = "logging")]
{
log_guard.set_bytes_sent(_body_len);
log_guard.finish(500);
}
return Ok(response);
}
};
let request_id = final_request_id(&req, &initial_request_id);
#[cfg(feature = "logging")]
{
log_guard.set_request_id(request_id.clone());
tracing::Span::current().record("req_id", request_id.as_str());
}
match action_result {
ActionResult::Redirect { status, url } => {
let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::FOUND);
let boxed: ResponseBody = Full::new(Bytes::new())
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.boxed();
let response = Response::builder()
.status(status_code)
.header("Location", &url)
.header("X-Request-ID", &request_id)
.body(boxed)?;
#[cfg(feature = "logging")]
{
log_guard.set_bytes_sent(0);
log_guard.finish(status_code.as_u16());
}
Ok(response)
}
ActionResult::Respond { status, body } => {
let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
let _body_len = body.len();
let boxed: ResponseBody = Full::new(Bytes::from(body))
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.boxed();
let response = Response::builder()
.status(status_code)
.header("X-Request-ID", &request_id)
.body(boxed)?;
#[cfg(feature = "logging")]
{
log_guard.set_bytes_sent(_body_len);
log_guard.finish(status_code.as_u16());
}
Ok(response)
}
ActionResult::ReverseProxy {
backend_url,
path_to_send,
connect_timeout: _,
read_timeout,
} => {
let backend_with_proto =
if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
backend_url
} else {
format!("http://{}", backend_url)
};
let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
parts.path_and_query = Some(path_to_send.parse()?);
let new_uri = Uri::from_parts(parts)?;
*req.uri_mut() = new_uri.clone();
let original_host_header = req.headers().get(hyper::header::HOST).cloned();
req.headers_mut().remove(hyper::header::HOST);
if let Some(authority) = new_uri.authority() {
if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>()
{
req.headers_mut().insert(hyper::header::HOST, host_value);
}
}
if let Some(host_value) = original_host_header.clone() {
req.headers_mut().insert("X-Forwarded-Host", host_value);
}
req.headers_mut().insert(
"X-Forwarded-Proto",
hyper::header::HeaderValue::from_static(if is_tls { "https" } else { "http" }),
);
if let Ok(ip_value) =
hyper::header::HeaderValue::from_str(&remote_addr.ip().to_string())
{
req.headers_mut().insert("X-Forwarded-For", ip_value);
}
req.headers_mut().remove(header::CONNECTION);
req.headers_mut().remove("accept-encoding");
let backend_timeout = read_timeout.unwrap_or(30);
match timeout(Duration::from_secs(backend_timeout), client.request(req)).await {
Ok(Ok(response)) => {
let status = response.status();
let headers = response.headers().clone();
let mut builder = Response::builder().status(status);
for (name, value) in headers.iter() {
if !is_hop_header(name) && name != header::CONTENT_LENGTH {
builder = builder.header(name, value);
}
}
let (_, incoming_body) = response.into_parts();
let boxed: ResponseBody = incoming_body
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.boxed();
let response = builder.header("X-Request-ID", &request_id).body(boxed)?;
#[cfg(feature = "logging")]
log_guard.finish(status.as_u16());
Ok(response)
}
Ok(Err(e)) => {
error!("Backend connection failed: {:?}", e);
if e.is_connect() {
error!(" Reason: Connection refused - backend unavailable");
} else {
error!(" Reason: Other connection error");
}
let (response, _body_len) = error_response_with_id(
StatusCode::BAD_GATEWAY,
"Backend service unavailable",
&request_id,
);
#[cfg(feature = "logging")]
{
log_guard.set_bytes_sent(_body_len);
log_guard.finish(502);
}
Ok(response)
}
Err(_) => {
error!(
"Backend request timed out after {} seconds",
backend_timeout
);
let (response, _body_len) = error_response_with_id(
StatusCode::GATEWAY_TIMEOUT,
"Backend request timed out",
&request_id,
);
#[cfg(feature = "logging")]
{
log_guard.set_bytes_sent(_body_len);
log_guard.finish(504);
}
Ok(response)
}
}
}
}
};
#[cfg(feature = "logging")]
let future = future.instrument(span);
future.await
}
fn error_response_with_id(
status: StatusCode,
message: &str,
request_id: &str,
) -> (Response<ResponseBody>, usize) {
let body = format!(
r#"<!DOCTYPE html>
<html>
<head><title>{} {}</title></head>
<body>
<h1>{} {}</h1>
<p>{}</p>
<hr>
<p><em>Rust Proxy Server</em></p>
</body>
</html>"#,
status.as_u16(),
status.canonical_reason().unwrap_or("Error"),
status.as_u16(),
status.canonical_reason().unwrap_or("Error"),
message
);
let body_len = body.len();
let full = Full::new(Bytes::from(body));
let boxed: ResponseBody = full
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.boxed();
let mut builder = Response::builder()
.status(status)
.header("Content-Type", "text/html; charset=utf-8");
if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
builder = builder.header("X-Request-ID", val);
}
let response = builder.body(boxed).unwrap_or_else(|_| {
Response::new(
Full::new(Bytes::from("Internal Server Error"))
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.boxed(),
)
});
(response, body_len)
}
/// Matches `path` against a `handle_path`-style `pattern` and returns the
/// part of the path that remains for nested directives.
///
/// * A pattern ending in `/*` matches when `path` starts with the text before
///   `/*` **on a path-segment boundary**: the remainder must be empty or begin
///   with `/`. So `/api/*` matches `/api` and `/api/users` but not `/apifoo`.
///   The remainder (possibly empty) is returned.
/// * Any other pattern must equal `path` exactly; the remainder is then `/`.
///
/// Returns `None` when the pattern does not match.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    match pattern.strip_suffix("/*") {
        Some(prefix) => {
            let remaining = path.strip_prefix(prefix)?;
            // An empty prefix (pattern "/*") matches everything. Otherwise the
            // match must not end in the middle of a path segment.
            let on_boundary =
                prefix.is_empty() || remaining.is_empty() || remaining.starts_with('/');
            if on_boundary {
                Some(remaining.to_string())
            } else {
                None
            }
        }
        None if pattern == path => Some("/".to_string()),
        None => None,
    }
}
/// Resolves the site configuration for a request's `Host` value.
///
/// Lookup order:
/// 1. `host` exactly as given.
/// 2. If `host` carries a port: the host with the port removed (and the
///    brackets removed for a bracketed IPv6 literal).
/// 3. If `host` carries no port: `host` plus the scheme's default port
///    (443 for TLS, 80 otherwise).
/// 4. TLS only, no port: the single TLS-enabled site whose address hostname
///    equals `host` ignoring ASCII case. Several candidates are treated as
///    ambiguous and yield no match.
pub fn find_site<'a>(config: &'a Config, host: &str, is_tls: bool) -> Option<&'a SiteConfig> {
    if let Some(exact) = config.sites.get(host) {
        return Some(exact);
    }

    let bracketed = host.starts_with('[');
    let closing_bracket = if bracketed { host.find(']') } else { None };

    // For a bracketed IPv6 literal only a colon after `]` denotes a port.
    let has_port = match (bracketed, closing_bracket) {
        (true, Some(idx)) => host[idx..].contains(':'),
        (true, None) => false,
        (false, _) => host.contains(':'),
    };

    if has_port {
        let bare_host = if bracketed {
            let end = closing_bracket.unwrap_or(host.len());
            &host[1..end]
        } else {
            // Everything before the first colon.
            host.split_once(':').map_or(host, |(name, _)| name)
        };
        return config.sites.get(bare_host);
    }

    let default_port = if is_tls { 443 } else { 80 };
    let with_default_port = format!("{}:{}", host, default_port);
    if let Some(site) = config.sites.get(&with_default_port) {
        return Some(site);
    }

    if !is_tls {
        return None;
    }

    // Last resort for TLS: accept a site on any port as long as it is the
    // only TLS site configured for this hostname.
    let mut tls_candidates = config.sites.values().filter(|candidate| {
        candidate.tls.is_some()
            && extract_hostname(&candidate.address).eq_ignore_ascii_case(host)
    });
    match (tls_candidates.next(), tls_candidates.next()) {
        (Some(only), None) => Some(only),
        _ => None,
    }
}
#[cfg(test)]
mod find_site_tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `Config` with one directive-less site per `(address, has_tls)`
    /// pair, keyed by the address. TLS sites get placeholder cert/key paths.
    fn make_config(sites: Vec<(&str, bool)>) -> Config {
        let sites: HashMap<_, _> = sites
            .into_iter()
            .map(|(addr, has_tls)| {
                let tls = has_tls.then(|| crate::config::TlsConfig {
                    cert_path: "/fake/cert.pem".to_string(),
                    key_path: "/fake/key.pem".to_string(),
                });
                let site = crate::config::SiteConfig {
                    address: addr.to_string(),
                    directives: vec![],
                    tls,
                };
                (addr.to_string(), site)
            })
            .collect();
        Config { sites }
    }

    #[test]
    fn test_exact_match() {
        let config = make_config(vec![("example.com:443", true)]);
        let found = find_site(&config, "example.com:443", true);
        assert!(found.is_some());
    }

    #[test]
    fn test_tls_host_without_port_finds_443() {
        let config = make_config(vec![("example.com:443", true)]);
        let found = find_site(&config, "example.com", true);
        assert!(
            found.is_some(),
            "Should find example.com:443 when Host has no port and is_tls=true"
        );
    }

    #[test]
    fn test_http_host_without_port_finds_80() {
        let config = make_config(vec![("example.com:80", false)]);
        let found = find_site(&config, "example.com", false);
        assert!(
            found.is_some(),
            "Should find example.com:80 when Host has no port and is_tls=false"
        );
    }

    #[test]
    fn test_tls_host_without_port_no_match_on_80() {
        let config = make_config(vec![("example.com:80", false)]);
        let found = find_site(&config, "example.com", true);
        assert!(found.is_none(), "TLS on port 443 should not find :80 site");
    }

    #[test]
    fn test_host_with_port_strips_port_fallback() {
        let config = make_config(vec![("example.com", false)]);
        let found = find_site(&config, "example.com:8080", false);
        assert!(
            found.is_some(),
            "Should strip port from Host and find config without port"
        );
    }

    #[test]
    fn test_tls_host_without_port_finds_non_standard_port() {
        let config = make_config(vec![("alpha.local:8443", true)]);
        let found = find_site(&config, "alpha.local", true);
        assert!(
            found.is_some(),
            "Should find alpha.local:8443 when Host has no port on TLS"
        );
    }

    #[test]
    fn test_no_match() {
        let config = make_config(vec![("other.com:443", true)]);
        let found = find_site(&config, "example.com", true);
        assert!(found.is_none());
    }
}