use anyhow::Error;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::header;
use hyper::{Request, Response, StatusCode, Uri};
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use std::sync::Arc;
use tokio::time::{timeout, Duration};
use tracing::{error, info};
use crate::config::Config;
use crate::proxy::ActionResult;
use crate::proxy::directives::{
handle_header, handle_method, handle_redirect, handle_respond, handle_reverse_proxy,
handle_strip_prefix, handle_uri_replace,
};
/// Type-erased response body shared by every response this module builds, whether the payload
/// is an in-memory buffer (`Full`) or a body streamed back from a backend (`Incoming`).
type ResponseBody = http_body_util::combinators::BoxBody<
    Bytes,
    Box<dyn std::error::Error + Send + Sync>,
>;
fn is_hop_header(name: &header::HeaderName) -> bool {
matches!(
name,
&header::CONNECTION
| &header::UPGRADE
| &header::TE
| &header::TRAILER
| &header::PROXY_AUTHENTICATE
| &header::PROXY_AUTHORIZATION
)
}
pub fn process_directives(
directives: &[crate::config::Directive],
req: &mut Request<Incoming>,
current_path: &str,
) -> Result<ActionResult, String> {
let mut modified_path = current_path.to_string();
for directive in directives {
match directive {
crate::config::Directive::Header { name, value } => {
if let Err(e) = handle_header(name, value.as_deref(), req) {
info!(" Failed to apply header {}: {}", name, e);
}
}
crate::config::Directive::UriReplace { find, replace } => {
handle_uri_replace(find, replace, &mut modified_path);
}
crate::config::Directive::StripPrefix { prefix } => {
handle_strip_prefix(prefix, &mut modified_path);
}
crate::config::Directive::HandlePath {
pattern,
directives: nested_directives,
} => {
if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
info!(" Matched handle_path: {}", pattern);
return process_directives(nested_directives, req, &remaining_path);
}
}
crate::config::Directive::Method {
methods,
directives: nested_directives,
} => {
if handle_method(methods, req) {
info!(" Matched method directive");
return process_directives(nested_directives, req, &modified_path);
}
}
crate::config::Directive::Redirect { status, url } => {
return Ok(handle_redirect(status, url));
}
crate::config::Directive::Respond { status, body } => {
return Ok(handle_respond(status, body));
}
crate::config::Directive::ReverseProxy {
to,
connect_timeout,
read_timeout,
} => {
return Ok(handle_reverse_proxy(
to,
&modified_path,
*connect_timeout,
*read_timeout,
));
}
}
}
Err(format!(
"No action directive (respond or reverse_proxy) found in configuration for path: {}",
current_path
))
}
pub async fn proxy(
mut req: Request<Incoming>,
client: Client<HttpsConnector<HttpConnector>, Incoming>,
config: Arc<Config>,
remote_addr: std::net::SocketAddr,
) -> Result<Response<ResponseBody>, Error> {
let path = req.uri().path().to_string();
let host = req
.headers()
.get(hyper::header::HOST)
.and_then(|h| h.to_str().ok())
.unwrap_or("localhost");
if tracing::enabled!(tracing::Level::INFO) {
}
let site_config = match config.sites.get(host) {
Some(config) => config,
None => {
error!("No configuration found for host: {}", host);
return Ok(error_response(
StatusCode::NOT_FOUND,
&format!("No configuration found for host: {}", host),
));
}
};
let action_result =
process_directives(&site_config.directives, &mut req, &path).map_err(anyhow::Error::msg)?;
match action_result {
ActionResult::Redirect { status, url } => {
let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::FOUND);
let boxed: ResponseBody = Full::new(Bytes::from(url.clone()))
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.boxed();
Ok(Response::builder()
.status(status_code)
.header("Location", &url)
.body(boxed)?)
}
ActionResult::Respond { status, body } => {
let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
let boxed: ResponseBody = Full::new(Bytes::from(body))
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.boxed();
Ok(Response::builder().status(status_code).body(boxed)?)
}
ActionResult::ReverseProxy {
backend_url,
path_to_send,
connect_timeout: _,
read_timeout,
} => {
let backend_with_proto =
if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
backend_url
} else {
format!("http://{}", backend_url)
};
let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
parts.path_and_query = Some(path_to_send.parse()?);
let new_uri = Uri::from_parts(parts)?;
if tracing::enabled!(tracing::Level::INFO) {
}
*req.uri_mut() = new_uri.clone();
let original_host_header = req.headers().get(hyper::header::HOST).cloned();
req.headers_mut().remove(hyper::header::HOST);
if let Some(authority) = new_uri.authority() {
if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>() {
req.headers_mut().insert(hyper::header::HOST, host_value);
}
}
if let Some(host_value) = original_host_header.clone() {
req.headers_mut().insert("X-Forwarded-Host", host_value);
}
let original_scheme = req.uri().scheme_str().unwrap_or("http");
match original_scheme {
"http" => {
req.headers_mut().insert(
"X-Forwarded-Proto",
hyper::header::HeaderValue::from_static("http"),
);
}
"https" => {
req.headers_mut().insert(
"X-Forwarded-Proto",
hyper::header::HeaderValue::from_static("https"),
);
}
_ => {} }
if let Ok(ip_value) =
hyper::header::HeaderValue::from_str(&remote_addr.ip().to_string())
{
req.headers_mut().insert("X-Forwarded-For", ip_value);
}
req.headers_mut().remove(header::CONNECTION);
req.headers_mut().remove("accept-encoding");
let backend_timeout = read_timeout.unwrap_or(30);
match timeout(Duration::from_secs(backend_timeout), client.request(req)).await {
Ok(Ok(response)) => {
let status = response.status();
let headers = response.headers().clone();
if tracing::enabled!(tracing::Level::INFO) {
}
let mut builder = Response::builder().status(status);
for (name, value) in headers.iter() {
if !is_hop_header(name) && name != header::CONTENT_LENGTH {
builder = builder.header(name, value);
}
}
let (_, incoming_body) = response.into_parts();
let boxed: ResponseBody = incoming_body
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.boxed();
Ok(builder.body(boxed)?)
}
Ok(Err(e)) => {
error!("Backend connection failed: {:?}", e);
if e.is_connect() {
error!(" Reason: Connection refused - backend unavailable");
} else {
error!(" Reason: Other connection error");
}
Ok(error_response(
StatusCode::BAD_GATEWAY,
"Backend service unavailable",
))
}
Err(_) => {
error!(
"Backend request timed out after {} seconds",
backend_timeout
);
Ok(error_response(
StatusCode::GATEWAY_TIMEOUT,
"Backend request timed out",
))
}
}
}
}
}
fn error_response(status: StatusCode, message: &str) -> Response<ResponseBody> {
let body = format!(
r#"<!DOCTYPE html>
<html>
<head><title>{} {}</title></head>
<body>
<h1>{} {}</h1>
<p>{}</p>
<hr>
<p><em>Rust Proxy Server</em></p>
</body>
</html>"#,
status.as_u16(),
status.canonical_reason().unwrap_or("Error"),
status.as_u16(),
status.canonical_reason().unwrap_or("Error"),
message
);
let full = Full::new(Bytes::from(body));
let boxed: ResponseBody = full
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
.boxed();
Response::builder()
.status(status)
.header("Content-Type", "text/html; charset=utf-8")
.body(boxed)
.unwrap()
}
/// Matches a request `path` against a `handle_path` `pattern` and returns the path left over
/// for nested directives, or `None` when the pattern does not match.
///
/// - `"/prefix/*"` matches `"/prefix"` itself and anything below it on a segment boundary
///   (`"/prefix/..."`), but not sibling paths such as `"/prefixed"`. The prefix is removed,
///   and an empty remainder becomes `"/"`.
/// - `"/*"` matches every path and leaves it unchanged.
/// - Any other pattern must equal `path` exactly and yields `"/"`.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    match pattern.strip_suffix("/*") {
        Some(prefix) => {
            let rest = path.strip_prefix(prefix)?;
            if rest.is_empty() {
                // Same remainder the exact-match branch produces.
                Some("/".to_string())
            } else if prefix.is_empty() || prefix.ends_with('/') || rest.starts_with('/') {
                Some(rest.to_string())
            } else {
                // e.g. "/api/*" vs "/apiv2": shared characters, but a different segment.
                None
            }
        }
        None if pattern == path => Some("/".to_string()),
        None => None,
    }
}