use axum::{
body::Body,
extract::Request,
http::HeaderName,
response::Response,
};
use reqwest::Client;
use std::time::Duration;
use crate::error::{ProxyError, ProxyResult};
/// Settings controlling how [`proxy_request`] forwards a request to a backend.
///
/// [`Default`] yields `preserve_host = false`, a 30-second timeout, a 10 MiB
/// request-body limit, and `add_forwarded_headers = true`.
#[derive(Debug, Clone)]
pub struct ProxyConfig {
    /// Whether the client's original `Host` header should reach the backend unchanged.
    /// NOTE(review): confirm `proxy_request` honours this flag as intended.
    pub preserve_host: bool,
    /// Timeout applied to the backend request through the HTTP client builder.
    pub timeout: Duration,
    /// Maximum number of bytes buffered from the incoming request body; a larger body
    /// makes the proxy fail with `ProxyError::RequestBody`. The backend response body is
    /// not subject to this limit.
    pub max_body_size: usize,
    /// When `true`, forwarding headers (currently `X-Forwarded-Host`) are added to the
    /// backend request.
    pub add_forwarded_headers: bool,
}
impl Default for ProxyConfig {
    /// Builds the stock configuration: `preserve_host` off, a 30-second backend timeout,
    /// a 10 MiB cap on the buffered request body, and forwarding headers on.
    fn default() -> Self {
        const TIMEOUT_SECS: u64 = 30;
        // 10 MiB expressed as a shift: 10 * 2^20 bytes.
        const MAX_BODY_BYTES: usize = 10 << 20;

        Self {
            add_forwarded_headers: true,
            max_body_size: MAX_BODY_BYTES,
            timeout: Duration::from_secs(TIMEOUT_SECS),
            preserve_host: false,
        }
    }
}
pub async fn proxy_request(
request: Request,
backend_url: &str,
config: ProxyConfig,
) -> ProxyResult<Response> {
let client = Client::builder()
.timeout(config.timeout)
.build()
.map_err(|e| ProxyError::ClientCreation(e.to_string()))?;
let method = request.method().clone();
let uri = request.uri().clone();
let headers = request.headers().clone();
let body = axum::body::to_bytes(request.into_body(), config.max_body_size)
.await
.map_err(|e| ProxyError::RequestBody(e.to_string()))?;
let path_and_query = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
let target_url = format!("{}{}", backend_url.trim_end_matches('/'), path_and_query);
tracing::debug!(
method = %method,
uri = %uri,
target = %target_url,
"Proxying request"
);
let mut req_builder = client.request(method.clone(), &target_url);
for (name, value) in headers.iter() {
if !is_hop_by_hop_header(name) {
req_builder = req_builder.header(name, value);
}
}
if config.add_forwarded_headers
&& let Some(host) = headers.get("host").and_then(|h| h.to_str().ok())
{
req_builder = req_builder.header("X-Forwarded-Host", host);
}
let backend_response = req_builder
.body(body.to_vec())
.send()
.await
.map_err(|e| ProxyError::BackendRequest(e.to_string()))?;
let status = backend_response.status();
let response_headers = backend_response.headers().clone();
let body_bytes = backend_response
.bytes()
.await
.map_err(|e| ProxyError::ResponseBody(e.to_string()))?;
let mut response = Response::new(Body::from(body_bytes));
*response.status_mut() = status;
for (name, value) in response_headers.iter() {
if !is_hop_by_hop_header(name) {
response.headers_mut().insert(name, value.clone());
}
}
Ok(response)
}
/// Returns `true` if `name` is a hop-by-hop header that a proxy must not forward.
///
/// The set follows RFC 9110 §7.6.1 / RFC 2616 §13.5.1, plus `Proxy-Connection`, a
/// non-standard but widely sent variant of `Connection`. Both `Trailer` (the actual
/// header name) and `Trailers` (the spelling in RFC 2616's list) are matched.
/// `HeaderName` stores names in lowercase, so no case folding (and no allocation) is
/// needed.
fn is_hop_by_hop_header(name: &HeaderName) -> bool {
    matches!(
        name.as_str(),
        "connection"
            | "keep-alive"
            | "proxy-authenticate"
            | "proxy-authorization"
            | "proxy-connection"
            | "te"
            | "trailer"
            | "trailers"
            | "transfer-encoding"
            | "upgrade"
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection-scoped headers are recognised as hop-by-hop; end-to-end ones are not.
    #[test]
    fn test_hop_by_hop_headers() {
        for hop in ["connection", "keep-alive", "upgrade"] {
            assert!(is_hop_by_hop_header(&HeaderName::from_static(hop)));
        }
        for end_to_end in ["content-type", "authorization"] {
            assert!(!is_hop_by_hop_header(&HeaderName::from_static(end_to_end)));
        }
    }

    /// `ProxyConfig::default()` produces the expected stock values.
    #[test]
    fn test_default_config() {
        let ProxyConfig {
            preserve_host,
            timeout,
            max_body_size,
            add_forwarded_headers,
        } = ProxyConfig::default();

        assert!(!preserve_host);
        assert!(add_forwarded_headers);
        assert_eq!(timeout, Duration::from_secs(30));
        assert_eq!(max_body_size, 10 * 1024 * 1024);
    }
}