use http_body_util::{BodyExt, Full};
use hyper::{Request, Response, StatusCode, body::Incoming};
use std::convert::Infallible;
use std::sync::Arc;
use crate::error::WiseGateError;
use crate::types::{ConfigProvider, RateLimiter};
use crate::{auth, headers, ip_filter, rate_limiter};
/// Entry point for every proxied request: applies the gateway's admission checks and, if the
/// request passes them all, forwards it to `forward_host:forward_port`.
///
/// Checks run in this order; the first failure short-circuits with an error response:
/// 1. real client IP extraction — rejected when `allowed_proxy_ips` is configured but no valid
///    client IP can be extracted from the request headers;
/// 2. IP block list (only when a client IP was resolved);
/// 3. blocked URL patterns;
/// 4. blocked HTTP methods;
/// 5. per-IP rate limiting (only when a client IP was resolved);
/// 6. Basic / Bearer authentication (only when auth is enabled).
///
/// Before forwarding, any client-supplied `X-Real-IP` header is removed and, when a client IP
/// was resolved, replaced with that IP. Every rejection is expressed as an HTTP response, so
/// this never returns `Err`.
pub async fn handle_request<C: ConfigProvider>(
    mut req: Request<Incoming>,
    forward_host: Arc<str>,
    forward_port: u16,
    limiter: RateLimiter,
    config: Arc<C>,
    http_client: reqwest::Client,
) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
    // A missing client IP is only tolerated when no proxy allowlist is configured; in that case
    // the IP-based checks below (block list, rate limiting) are skipped.
    let real_client_ip: Option<String> =
        ip_filter::extract_and_validate_real_ip(req.headers(), config.as_ref());
    if real_client_ip.is_none() && config.allowed_proxy_ips().is_some() {
        let err = WiseGateError::InvalidIp("missing or invalid proxy headers".into());
        return Ok(create_error_response(err.status_code(), err.user_message()));
    }

    if let Some(ref ip) = real_client_ip
        && ip_filter::is_ip_blocked(ip, config.as_ref())
    {
        let err = WiseGateError::IpBlocked(ip.clone());
        return Ok(create_error_response(err.status_code(), err.user_message()));
    }

    let request_path = req.uri().path();
    if is_url_pattern_blocked(request_path, config.as_ref()) {
        let err = WiseGateError::PatternBlocked(request_path.to_string());
        return Ok(create_error_response(err.status_code(), err.user_message()));
    }

    let request_method = req.method().as_str();
    if is_method_blocked(request_method, config.as_ref()) {
        let err = WiseGateError::MethodBlocked(request_method.to_string());
        return Ok(create_error_response(err.status_code(), err.user_message()));
    }

    // Rate limit BEFORE authenticating: if auth ran first, requests with wrong credentials
    // would be rejected without ever consuming the client's budget, allowing unthrottled
    // credential brute-forcing from a single IP.
    if let Some(ref ip) = real_client_ip
        && !rate_limiter::check_rate_limit(&limiter, ip, config.as_ref()).await
    {
        let err = WiseGateError::RateLimitExceeded(ip.clone());
        return Ok(create_error_response(err.status_code(), err.user_message()));
    }

    if config.is_auth_enabled() {
        // A non-visible-ASCII Authorization value is treated as absent.
        let auth_header = req
            .headers()
            .get(headers::AUTHORIZATION)
            .and_then(|v| v.to_str().ok());
        let basic_auth_enabled = config.is_basic_auth_enabled();
        let bearer_auth_enabled = config.is_bearer_auth_enabled();
        let basic_auth_passed =
            basic_auth_enabled && auth::check_basic_auth(auth_header, config.auth_credentials());
        let bearer_auth_passed =
            bearer_auth_enabled && auth::check_bearer_token(auth_header, config.bearer_token());
        // Either enabled scheme succeeding is sufficient.
        if !basic_auth_passed && !bearer_auth_passed {
            return Ok(create_unauthorized_response(config.auth_realm()));
        }
    }

    // Never trust a client-supplied X-Real-IP; only the resolved IP is forwarded.
    req.headers_mut().remove(headers::X_REAL_IP);
    if let Some(ref ip) = real_client_ip
        && let Ok(header_value) = ip.parse()
    {
        req.headers_mut().insert(headers::X_REAL_IP, header_value);
    }

    forward_request(
        req,
        &forward_host,
        forward_port,
        config.as_ref(),
        &http_client,
    )
    .await
}
async fn forward_request(
req: Request<Incoming>,
host: &str,
port: u16,
config: &impl ConfigProvider,
http_client: &reqwest::Client,
) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
let proxy_config = config.proxy_config();
let (parts, body) = req.into_parts();
if proxy_config.max_body_size > 0
&& let Some(content_length) = parts
.headers
.get(headers::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<usize>().ok())
&& content_length > proxy_config.max_body_size
{
let err = WiseGateError::BodyTooLarge {
size: content_length,
max: proxy_config.max_body_size,
};
return Ok(create_error_response(err.status_code(), err.user_message()));
}
let body_bytes = match body.collect().await {
Ok(bytes) => {
let collected_bytes = bytes.to_bytes();
if proxy_config.max_body_size > 0 && collected_bytes.len() > proxy_config.max_body_size
{
let err = WiseGateError::BodyTooLarge {
size: collected_bytes.len(),
max: proxy_config.max_body_size,
};
return Ok(create_error_response(err.status_code(), err.user_message()));
}
collected_bytes
}
Err(e) => {
let err = WiseGateError::BodyReadError(e.to_string());
return Ok(create_error_response(err.status_code(), err.user_message()));
}
};
let strip_auth = config.is_auth_enabled() && !config.forward_authorization_header();
forward_with_reqwest(parts, body_bytes, host, port, http_client, strip_auth).await
}
/// Sends the buffered request to `http://{host}:{port}{path?query}` through `client` and
/// converts the upstream reply into a hyper response.
///
/// Every request header is copied except `Host`, `Content-Length`, hop-by-hop headers, and
/// `Authorization` when `strip_auth` is set. Transport failures are mapped to
/// `UpstreamTimeout`, `UpstreamConnectionFailed`, or `ProxyError` responses, so this never
/// returns `Err`.
async fn forward_with_reqwest(
    parts: hyper::http::request::Parts,
    body_bytes: bytes::Bytes,
    host: &str,
    port: u16,
    client: &reqwest::Client,
    strip_auth: bool,
) -> Result<Response<Full<bytes::Bytes>>, Infallible> {
    let destination_uri = format!(
        "http://{}:{}{}",
        host,
        port,
        parts.uri.path_and_query().map_or("", |pq| pq.as_str())
    );
    let method = match reqwest::Method::from_bytes(parts.method.as_str().as_bytes()) {
        Ok(m) => m,
        Err(_) => {
            let err =
                WiseGateError::MethodBlocked(format!("{} (unsupported)", parts.method.as_str()));
            return Ok(create_error_response(err.status_code(), err.user_message()));
        }
    };

    let mut req_builder = client.request(method, &destination_uri);
    for (name, value) in parts.headers.iter() {
        if name != headers::HOST
            && name != headers::CONTENT_LENGTH
            && !(strip_auth && name == headers::AUTHORIZATION)
            && !headers::is_hop_by_hop(name.as_str())
        {
            // Forward the raw bytes: `to_str()` rejects legal non-ASCII (obs-text) values,
            // which would silently drop such headers.
            req_builder = req_builder.header(name.as_str(), value.as_bytes());
        }
    }
    if !body_bytes.is_empty() {
        req_builder = req_builder.body(body_bytes);
    }

    match req_builder.send().await {
        Ok(response) => Ok(convert_upstream_response(response).await),
        Err(err) => {
            let wise_err = if err.is_timeout() {
                WiseGateError::UpstreamTimeout(err.to_string())
            } else if err.is_connect() {
                WiseGateError::UpstreamConnectionFailed(err.to_string())
            } else {
                WiseGateError::ProxyError(err.to_string())
            };
            Ok(create_error_response(
                wise_err.status_code(),
                wise_err.user_message(),
            ))
        }
    }
}

/// Buffers an upstream reply and rebuilds it as a hyper response with the same status and all
/// non-hop-by-hop headers. Read or build failures become gateway error responses.
async fn convert_upstream_response(mut response: reqwest::Response) -> Response<Full<bytes::Bytes>> {
    let status = response.status();
    // `bytes()` consumes the response, so move the headers out first (no clone needed).
    let resp_headers = std::mem::take(response.headers_mut());
    let body_bytes = match response.bytes().await {
        Ok(bytes) => bytes,
        Err(e) => {
            let err = WiseGateError::BodyReadError(format!("response: {}", e));
            return create_error_response(err.status_code(), err.user_message());
        }
    };

    let mut hyper_response = match Response::builder()
        .status(status.as_u16())
        .body(Full::new(body_bytes))
    {
        Ok(resp) => resp,
        Err(e) => {
            let err = WiseGateError::ProxyError(format!("Failed to build response: {}", e));
            return create_error_response(err.status_code(), err.user_message());
        }
    };

    for (name, value) in resp_headers.iter() {
        // Names and values are re-parsed into hyper's header types rather than assumed to be
        // the same types reqwest uses.
        if !headers::is_hop_by_hop(name.as_str())
            && let (Ok(hyper_name), Ok(hyper_value)) = (
                hyper::header::HeaderName::from_bytes(name.as_str().as_bytes()),
                hyper::header::HeaderValue::from_bytes(value.as_bytes()),
            )
        {
            // `append`, not `insert`: `iter()` yields each value of a repeated header
            // separately, and `insert` would keep only the last one (e.g. one `Set-Cookie`).
            hyper_response.headers_mut().append(hyper_name, hyper_value);
        }
    }
    hyper_response
}
pub fn create_error_response(status: StatusCode, message: &str) -> Response<Full<bytes::Bytes>> {
Response::builder()
.status(status)
.header(headers::CONTENT_TYPE, "text/plain")
.body(Full::new(bytes::Bytes::from(message.to_string())))
.unwrap_or_else(|_| {
Response::new(Full::new(bytes::Bytes::from("Internal Server Error")))
})
}
/// Builds a `401 Unauthorized` response that challenges the client for HTTP Basic credentials.
///
/// `realm` is embedded in `WWW-Authenticate` as a quoted-string. Backslashes and double quotes
/// are backslash-escaped. Control characters other than tab are dropped, because they are
/// invalid in header values and would make the whole response fail to build. If building still
/// fails, the fallback keeps the `401` status instead of defaulting to `200 OK`.
pub fn create_unauthorized_response(realm: &str) -> Response<Full<bytes::Bytes>> {
    let mut sanitized_realm = String::with_capacity(realm.len());
    for c in realm.chars().filter(|&c| c == '\t' || !c.is_control()) {
        if matches!(c, '\\' | '"') {
            sanitized_realm.push('\\');
        }
        sanitized_realm.push(c);
    }
    Response::builder()
        .status(StatusCode::UNAUTHORIZED)
        .header(
            headers::WWW_AUTHENTICATE,
            format!("Basic realm=\"{}\"", sanitized_realm),
        )
        .header(headers::CONTENT_TYPE, "text/plain")
        .body(Full::new(bytes::Bytes::from("401 Unauthorized")))
        .unwrap_or_else(|_| {
            let mut fallback = Response::new(Full::new(bytes::Bytes::from("401 Unauthorized")));
            *fallback.status_mut() = StatusCode::UNAUTHORIZED;
            fallback
        })
}
/// Returns `true` when `path` contains any configured blocked pattern, case-insensitively.
///
/// Both the raw path and its single percent-decoded form are checked, so encoded variants such
/// as `/.ph%70` are caught. Patterns are lowercased as well as the path, so matching does not
/// depend on how a pattern was capitalised in the configuration (e.g. `WEB-INF`).
fn is_url_pattern_blocked(path: &str, config: &impl ConfigProvider) -> bool {
    let blocked_patterns = config.blocked_patterns();
    if blocked_patterns.is_empty() {
        return false;
    }
    let raw_lower = path.to_lowercase();
    let decoded = url_decode(path);
    // Only keep a decoded variant when decoding actually changed something.
    let decoded_lower = (decoded != path).then(|| decoded.to_lowercase());
    blocked_patterns.iter().any(|pattern| {
        let pattern = pattern.as_str().to_lowercase();
        raw_lower.contains(pattern.as_str())
            || decoded_lower
                .as_deref()
                .is_some_and(|dl| dl.contains(pattern.as_str()))
    })
}
/// Percent-decodes `input` exactly once.
///
/// Every `%XY` where both `X` and `Y` are hexadecimal digits (either case) becomes the byte
/// `0xXY`. Any other `%`, including a truncated `%` or `%X` at the end, is copied through
/// unchanged. The resulting bytes are read as UTF-8, with invalid sequences replaced by U+FFFD.
/// Decoding is not repeated, so `%252e` yields `%2e`, not `.`.
fn url_decode(input: &str) -> String {
    let raw = input.as_bytes();
    let mut decoded = Vec::with_capacity(raw.len());
    let mut pos = 0;
    while let Some(&byte) = raw.get(pos) {
        let escape = match (byte, raw.get(pos + 1), raw.get(pos + 2)) {
            (b'%', Some(&hi), Some(&lo)) => hex_digit(hi).zip(hex_digit(lo)),
            _ => None,
        };
        match escape {
            Some((hi, lo)) => {
                decoded.push((hi << 4) | lo);
                pos += 3;
            }
            None => {
                decoded.push(byte);
                pos += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}

/// Returns the numeric value of an ASCII hexadecimal digit, or `None` for any other byte.
fn hex_digit(byte: u8) -> Option<u8> {
    let base = match byte {
        b'0'..=b'9' => b'0',
        b'a'..=b'f' => b'a' - 10,
        b'A'..=b'F' => b'A' - 10,
        _ => return None,
    };
    Some(byte - base)
}
/// Returns `true` if `method` appears in the configured list of blocked methods, compared
/// ASCII-case-insensitively (so `trace` matches a configured `TRACE`).
fn is_method_blocked(method: &str, config: &impl ConfigProvider) -> bool {
    for blocked in config.blocked_methods().iter() {
        if blocked.eq_ignore_ascii_case(method) {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestConfig;
    use http_body_util::BodyExt;

    /// Checks `url_decode` against each `(input, expected)` pair, naming the failing input.
    fn assert_decodes(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(url_decode(input), expected, "url_decode({input:?})");
        }
    }

    /// Asserts that every path in `paths` has the expected blocked status under `config`.
    fn assert_paths_blocked(config: &TestConfig, paths: &[&str], blocked: bool) {
        for &path in paths {
            assert_eq!(is_url_pattern_blocked(path, config), blocked, "path {path:?}");
        }
    }

    /// Asserts that every method in `methods` has the expected blocked status under `config`.
    fn assert_methods_blocked(config: &TestConfig, methods: &[&str], blocked: bool) {
        for &method in methods {
            assert_eq!(is_method_blocked(method, config), blocked, "method {method:?}");
        }
    }

    /// Returns the `WWW-Authenticate` header of the 401 response built for `realm`.
    fn www_authenticate(realm: &str) -> String {
        let response = create_unauthorized_response(realm);
        let value = response.headers().get("www-authenticate").unwrap();
        value.to_str().unwrap().to_owned()
    }

    /// Drains a response body into a single buffer.
    async fn read_body(response: Response<Full<bytes::Bytes>>) -> bytes::Bytes {
        response.into_body().collect().await.unwrap().to_bytes()
    }

    #[test]
    fn test_url_decode_no_encoding() {
        assert_decodes(&[("/path/to/file", "/path/to/file"), ("hello", "hello"), ("", "")]);
    }

    #[test]
    fn test_url_decode_simple_encoding() {
        assert_decodes(&[("%20", " "), ("hello%20world", "hello world"), ("%2F", "/")]);
    }

    #[test]
    fn test_url_decode_dot_encoding() {
        assert_decodes(&[("%2e", "."), ("%2E", "."), (".%2ephp", "..php")]);
    }

    #[test]
    fn test_url_decode_php_bypass() {
        assert_decodes(&[(".ph%70", ".php"), ("%2ephp", ".php"), (".%70%68%70", ".php")]);
    }

    #[test]
    fn test_url_decode_env_bypass() {
        assert_decodes(&[(".%65nv", ".env"), ("%2eenv", ".env"), ("%2e%65%6e%76", ".env")]);
    }

    #[test]
    fn test_url_decode_multiple_encodings() {
        assert_decodes(&[("%2F%2e%2e%2Fetc%2Fpasswd", "/../etc/passwd")]);
    }

    #[test]
    fn test_url_decode_invalid_hex() {
        assert_decodes(&[("%GG", "%GG"), ("%", "%"), ("%2", "%2"), ("%ZZ", "%ZZ")]);
    }

    #[test]
    fn test_url_decode_mixed_content() {
        assert_decodes(&[
            ("path%2Fto%2Ffile.txt", "path/to/file.txt"),
            ("hello%20%26%20world", "hello & world"),
        ]);
    }

    #[test]
    fn test_url_decode_unicode() {
        assert_decodes(&[("%C3%A9", "é"), ("caf%C3%A9", "café")]);
    }

    #[test]
    fn test_url_pattern_blocked_simple() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);
        assert_paths_blocked(&config, &["/file.php", "/.env", "/path/to/file.php"], true);
    }

    #[test]
    fn test_url_pattern_not_blocked() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env"]);
        assert_paths_blocked(&config, &["/file.html", "/path/to/file.js", "/"], false);
    }

    #[test]
    fn test_url_pattern_blocked_empty_patterns() {
        let config = TestConfig::new();
        assert_paths_blocked(&config, &["/file.php", "/.env"], false);
    }

    #[test]
    fn test_url_pattern_blocked_bypass_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php", ".env", "admin"]);
        assert_paths_blocked(&config, &["/.ph%70", "/%2eenv", "/adm%69n"], true);
    }

    #[test]
    fn test_url_pattern_blocked_double_encoding_attempt() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php"]);
        assert_paths_blocked(&config, &["/.ph%70"], true);
    }

    #[test]
    fn test_url_pattern_blocked_case_insensitive() {
        let config = TestConfig::new().with_blocked_patterns(vec![".php"]);
        assert_paths_blocked(&config, &["/file.PHP", "/file.php", "/file.Php"], true);
    }

    #[test]
    fn test_url_pattern_blocked_partial_match() {
        let config = TestConfig::new().with_blocked_patterns(vec!["admin"]);
        // Substring matching: "/administrator" is blocked too.
        assert_paths_blocked(&config, &["/admin/panel", "/path/admin", "/administrator"], true);
    }

    #[test]
    fn test_method_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);
        assert_methods_blocked(&config, &["TRACE", "CONNECT"], true);
    }

    #[test]
    fn test_method_not_blocked() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE", "CONNECT"]);
        assert_methods_blocked(&config, &["GET", "POST", "PUT", "DELETE"], false);
    }

    #[test]
    fn test_method_blocked_empty_list() {
        let config = TestConfig::new();
        assert_methods_blocked(&config, &["TRACE", "GET"], false);
    }

    #[test]
    fn test_method_blocked_case_insensitive() {
        let config = TestConfig::new().with_blocked_methods(vec!["TRACE"]);
        assert_methods_blocked(&config, &["TRACE", "trace", "Trace"], true);
    }

    #[test]
    fn test_create_error_response_status() {
        for (status, message) in [
            (StatusCode::NOT_FOUND, "Not Found"),
            (StatusCode::FORBIDDEN, "Forbidden"),
            (StatusCode::TOO_MANY_REQUESTS, "Rate limited"),
        ] {
            assert_eq!(create_error_response(status, message).status(), status);
        }
    }

    #[test]
    fn test_create_error_response_content_type() {
        let response = create_error_response(StatusCode::NOT_FOUND, "Not Found");
        let content_type = response.headers().get("content-type").unwrap();
        assert_eq!(content_type, "text/plain");
    }

    #[tokio::test]
    async fn test_create_error_response_body() {
        let body = read_body(create_error_response(StatusCode::NOT_FOUND, "Resource not found")).await;
        assert_eq!(body, "Resource not found");
    }

    #[tokio::test]
    async fn test_create_error_response_empty_message() {
        let body = read_body(create_error_response(StatusCode::NO_CONTENT, "")).await;
        assert_eq!(body, "");
    }

    #[test]
    fn test_unauthorized_response_status() {
        let response = create_unauthorized_response("WiseGate");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_unauthorized_response_www_authenticate_header() {
        assert_eq!(www_authenticate("WiseGate"), "Basic realm=\"WiseGate\"");
    }

    #[test]
    fn test_unauthorized_response_realm_with_quotes() {
        assert_eq!(www_authenticate("My \"Realm\""), "Basic realm=\"My \\\"Realm\\\"\"");
    }

    #[test]
    fn test_unauthorized_response_realm_with_backslash() {
        assert_eq!(www_authenticate("My\\Realm"), "Basic realm=\"My\\\\Realm\"");
    }

    #[test]
    fn test_unauthorized_response_content_type() {
        let response = create_unauthorized_response("WiseGate");
        let content_type = response.headers().get("content-type").unwrap();
        assert_eq!(content_type, "text/plain");
    }

    #[test]
    fn test_url_decode_double_encoding_not_decoded_twice() {
        assert_decodes(&[("%252e", "%2e"), ("%2565nv", "%65nv")]);
    }
}