1use axum::{
6 body::Body,
7 extract::Request,
8 http::HeaderName,
9 response::Response,
10};
11use reqwest::Client;
12use std::time::Duration;
13
14use crate::error::{ProxyError, ProxyResult};
15
/// Settings controlling how `proxy_request` forwards a request to a backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyConfig {
    /// Whether the incoming `Host` header should be forwarded to the backend
    /// unchanged, rather than derived from the backend URL.
    pub preserve_host: bool,
    /// Timeout applied by the HTTP client to each backend request.
    pub timeout: Duration,
    /// Maximum size, in bytes, of the request body that is buffered and
    /// forwarded; larger bodies are rejected.
    pub max_body_size: usize,
    /// Whether to add an `X-Forwarded-Host` header carrying the client's
    /// original host.
    pub add_forwarded_headers: bool,
}
28
29impl Default for ProxyConfig {
30 fn default() -> Self {
31 Self {
32 preserve_host: false,
33 timeout: Duration::from_secs(30),
34 max_body_size: 10 * 1024 * 1024, add_forwarded_headers: true,
36 }
37 }
38}
39
40pub async fn proxy_request(
53 request: Request,
54 backend_url: &str,
55 config: ProxyConfig,
56) -> ProxyResult<Response> {
57 let client = Client::builder()
58 .timeout(config.timeout)
59 .build()
60 .map_err(|e| ProxyError::ClientCreation(e.to_string()))?;
61
62 let method = request.method().clone();
64 let uri = request.uri().clone();
65 let headers = request.headers().clone();
66
67 let body = axum::body::to_bytes(request.into_body(), config.max_body_size)
68 .await
69 .map_err(|e| ProxyError::RequestBody(e.to_string()))?;
70
71 let path_and_query = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
73 let target_url = format!("{}{}", backend_url.trim_end_matches('/'), path_and_query);
74
75 tracing::debug!(
76 method = %method,
77 uri = %uri,
78 target = %target_url,
79 "Proxying request"
80 );
81
82 let mut req_builder = client.request(method.clone(), &target_url);
84
85 for (name, value) in headers.iter() {
87 if !is_hop_by_hop_header(name) {
88 req_builder = req_builder.header(name, value);
89 }
90 }
91
92 if config.add_forwarded_headers
94 && let Some(host) = headers.get("host").and_then(|h| h.to_str().ok())
95 {
96 req_builder = req_builder.header("X-Forwarded-Host", host);
97 }
98 let backend_response = req_builder
102 .body(body.to_vec())
103 .send()
104 .await
105 .map_err(|e| ProxyError::BackendRequest(e.to_string()))?;
106
107 let status = backend_response.status();
109 let response_headers = backend_response.headers().clone();
110 let body_bytes = backend_response
111 .bytes()
112 .await
113 .map_err(|e| ProxyError::ResponseBody(e.to_string()))?;
114
115 let mut response = Response::new(Body::from(body_bytes));
116 *response.status_mut() = status;
117
118 for (name, value) in response_headers.iter() {
120 if !is_hop_by_hop_header(name) {
121 response.headers_mut().insert(name, value.clone());
122 }
123 }
124
125 Ok(response)
126}
127
128fn is_hop_by_hop_header(name: &HeaderName) -> bool {
133 matches!(
134 name.as_str().to_lowercase().as_str(),
135 "connection"
136 | "keep-alive"
137 | "proxy-authenticate"
138 | "proxy-authorization"
139 | "te"
140 | "trailers"
141 | "transfer-encoding"
142 | "upgrade"
143 )
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Hop-by-hop names are recognised; ordinary end-to-end headers are not.
    #[test]
    fn test_hop_by_hop_headers() {
        for hop in ["connection", "keep-alive", "upgrade"] {
            assert!(
                is_hop_by_hop_header(&HeaderName::from_static(hop)),
                "{hop} should be treated as hop-by-hop"
            );
        }

        for end_to_end in ["content-type", "authorization"] {
            assert!(
                !is_hop_by_hop_header(&HeaderName::from_static(end_to_end)),
                "{end_to_end} should be forwarded"
            );
        }
    }

    /// `ProxyConfig::default()` yields the expected default values.
    #[test]
    fn test_default_config() {
        let ProxyConfig {
            preserve_host,
            timeout,
            max_body_size,
            add_forwarded_headers,
        } = ProxyConfig::default();

        assert!(!preserve_host);
        assert!(add_forwarded_headers);
        assert_eq!(timeout, Duration::from_secs(30));
        assert_eq!(max_body_size, 10 * 1024 * 1024);
    }
}