1use axum::http::{header, HeaderMap, HeaderName, HeaderValue, Uri};
4use hyper::body::Incoming;
5use hyper_util::{client::legacy::Client, rt::TokioExecutor};
6use std::sync::Arc;
7use std::time::Duration;
8
9use crate::prelude::*;
10use crate::ProxySiteEntry;
11
/// Lowercase names of connection-scoped ("hop-by-hop") headers that a proxy must not forward
/// (RFC 9110 §7.6.1, RFC 7230 §6.1).
///
/// `upgrade` is included so that protocol switches are only relayed on the dedicated WebSocket
/// path. `proxy-connection` is a widely seen non-standard variant of `connection`. `trailers`
/// (the misspelling used by RFC 2616) is kept next to the correct `trailer` for compatibility.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "proxy-connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
];
22
23fn is_hop_by_hop(name: &HeaderName) -> bool {
25 HOP_BY_HOP_HEADERS.iter().any(|h| name.as_str().eq_ignore_ascii_case(h))
26}
27
28fn is_websocket_upgrade(headers: &HeaderMap) -> bool {
30 headers
31 .get(header::UPGRADE)
32 .and_then(|v| v.to_str().ok())
33 .is_some_and(|v| v.eq_ignore_ascii_case("websocket"))
34}
35
36fn build_backend_uri(entry: &ProxySiteEntry, original_uri: &Uri) -> ClResult<Uri> {
38 let mut backend = entry.backend_url.clone();
39 let combined_path = format!("{}{}", backend.path().trim_end_matches('/'), original_uri.path());
40 backend.set_path(&combined_path);
41 backend.set_query(original_uri.query());
42 debug!("Proxy backend URI: {} (combined_path={:?})", backend.as_str(), combined_path);
43 backend
44 .as_str()
45 .parse::<Uri>()
46 .map_err(|e| Error::Internal(format!("failed to build backend URI: {}", e)))
47}
48
49fn copy_headers(src: &HeaderMap, dst: &mut HeaderMap, is_websocket: bool) {
51 for (name, value) in src {
52 if is_hop_by_hop(name) {
54 if is_websocket && name == header::UPGRADE {
55 dst.insert(name.clone(), value.clone());
56 }
57 continue;
58 }
59 dst.append(name.clone(), value.clone());
60 }
61}
62
63pub async fn handle_proxy_request(
65 entry: Arc<ProxySiteEntry>,
66 req: hyper::Request<Incoming>,
67 peer_addr: &str,
68) -> Result<hyper::Response<Incoming>, Error> {
69 let is_ws = is_websocket_upgrade(req.headers()) && entry.config.websocket.unwrap_or(true);
70
71 if is_ws {
72 return handle_websocket_proxy(entry, req, peer_addr).await;
73 }
74
75 let backend_uri = build_backend_uri(&entry, req.uri())?;
76
77 let mut backend_headers = HeaderMap::new();
79 copy_headers(req.headers(), &mut backend_headers, false);
80
81 let preserve_host = entry.config.preserve_host.unwrap_or(true);
83 if preserve_host {
84 if let Some(host) = req.headers().get(header::HOST) {
86 backend_headers.insert(header::HOST, host.clone());
87 }
88 } else if let Some(host) = entry.backend_url.host_str() {
89 let host_val = if let Some(port) = entry.backend_url.port() {
91 format!("{}:{}", host, port)
92 } else {
93 host.to_string()
94 };
95 if let Ok(hv) = HeaderValue::from_str(&host_val) {
96 backend_headers.insert(header::HOST, hv);
97 }
98 }
99
100 let forward_headers = if entry.proxy_type.as_ref() == "basic" {
102 true
103 } else {
104 entry.config.forward_headers.unwrap_or(true)
105 };
106 if forward_headers {
107 if let Ok(hv) = HeaderValue::from_str(peer_addr) {
108 backend_headers.insert(HeaderName::from_static("x-forwarded-for"), hv.clone());
109 backend_headers.insert(HeaderName::from_static("x-real-ip"), hv);
110 }
111 backend_headers.insert(
112 HeaderName::from_static("x-forwarded-proto"),
113 HeaderValue::from_static("https"),
114 );
115 if let Ok(hv) = HeaderValue::from_str(&entry.domain) {
116 backend_headers.insert(HeaderName::from_static("x-forwarded-host"), hv);
117 }
118 }
119
120 if let Some(custom_headers) = &entry.config.custom_headers {
122 for (name, value) in custom_headers {
123 if let (Ok(hn), Ok(hv)) =
124 (HeaderName::from_bytes(name.as_bytes()), HeaderValue::from_str(value))
125 {
126 backend_headers.insert(hn, hv);
127 }
128 }
129 }
130
131 let method = req.method().clone();
133 let body = req.into_body();
134
135 let mut backend_req = hyper::Request::builder().method(method).uri(backend_uri);
136
137 if let Some(headers) = backend_req.headers_mut() {
138 *headers = backend_headers;
139 }
140
141 let backend_req = backend_req
142 .body(body)
143 .map_err(|e| Error::Internal(format!("failed to build backend request: {}", e)))?;
144
145 let connect_timeout =
147 Duration::from_secs(u64::from(entry.config.connect_timeout_secs.unwrap_or(5)));
148 let read_timeout = Duration::from_secs(u64::from(entry.config.read_timeout_secs.unwrap_or(30)));
149
150 let scheme = entry.backend_url.scheme();
152 match send_backend_request(scheme, connect_timeout, read_timeout, backend_req).await {
153 Ok(mut backend_resp) => {
154 let headers_to_remove: Vec<HeaderName> = backend_resp
156 .headers()
157 .keys()
158 .filter(|name| is_hop_by_hop(name))
159 .cloned()
160 .collect();
161 for name in headers_to_remove {
162 backend_resp.headers_mut().remove(&name);
163 }
164 Ok(backend_resp)
165 }
166 Err(e @ Error::Timeout) => {
167 warn!("Proxy backend timeout for {}", entry.domain);
168 Err(e)
169 }
170 Err(e) => {
171 warn!("Proxy backend error for {}: {}", entry.domain, e);
172 Err(e)
173 }
174 }
175}
176
177async fn handle_websocket_proxy(
179 entry: Arc<ProxySiteEntry>,
180 req: hyper::Request<Incoming>,
181 peer_addr: &str,
182) -> Result<hyper::Response<Incoming>, Error> {
183 let backend_uri = build_backend_uri(&entry, req.uri())?;
186
187 let mut backend_headers = HeaderMap::new();
188 for (name, value) in req.headers() {
190 if is_hop_by_hop(name) && name != header::UPGRADE {
191 continue;
192 }
193 backend_headers.append(name.clone(), value.clone());
194 }
195
196 let preserve_host = entry.config.preserve_host.unwrap_or(true);
198 if !preserve_host {
199 if let Some(host) = entry.backend_url.host_str() {
200 let host_val = if let Some(port) = entry.backend_url.port() {
201 format!("{}:{}", host, port)
202 } else {
203 host.to_string()
204 };
205 if let Ok(hv) = HeaderValue::from_str(&host_val) {
206 backend_headers.insert(header::HOST, hv);
207 }
208 }
209 }
210
211 let forward_headers = if entry.proxy_type.as_ref() == "basic" {
213 true
214 } else {
215 entry.config.forward_headers.unwrap_or(true)
216 };
217 if forward_headers {
218 if let Ok(hv) = HeaderValue::from_str(peer_addr) {
219 backend_headers.insert(HeaderName::from_static("x-forwarded-for"), hv);
220 }
221 backend_headers.insert(
222 HeaderName::from_static("x-forwarded-proto"),
223 HeaderValue::from_static("https"),
224 );
225 }
226
227 backend_headers.insert(header::CONNECTION, HeaderValue::from_static("Upgrade"));
229
230 let method = req.method().clone();
231 let body = req.into_body();
232
233 let mut backend_req = hyper::Request::builder().method(method).uri(backend_uri);
234
235 if let Some(headers) = backend_req.headers_mut() {
236 *headers = backend_headers;
237 }
238
239 let backend_req = backend_req
240 .body(body)
241 .map_err(|e| Error::Internal(format!("failed to build ws backend request: {}", e)))?;
242
243 let connect_timeout =
245 Duration::from_secs(u64::from(entry.config.connect_timeout_secs.unwrap_or(5)));
246
247 let scheme = entry.backend_url.scheme();
248 match send_backend_request(scheme, connect_timeout, connect_timeout, backend_req).await {
249 Ok(backend_resp) => Ok(backend_resp),
250 Err(e @ Error::Timeout) => {
251 warn!("WebSocket proxy backend timeout for {}", entry.domain);
252 Err(e)
253 }
254 Err(e) => {
255 warn!("WebSocket proxy backend error for {}: {}", entry.domain, e);
256 Err(e)
257 }
258 }
259}
260
261async fn send_backend_request(
263 scheme: &str,
264 connect_timeout: Duration,
265 timeout: Duration,
266 req: hyper::Request<Incoming>,
267) -> Result<hyper::Response<Incoming>, Error> {
268 let result = if scheme == "https" {
269 let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
270 .with_native_roots()
271 .map_err(|_| Error::ConfigError("no native root CA certificates found".into()))?
272 .https_only()
273 .enable_http1()
274 .build();
275 let client: Client<_, Incoming> = Client::builder(TokioExecutor::new())
276 .pool_idle_timeout(connect_timeout)
277 .build(https_connector);
278 tokio::time::timeout(timeout, client.request(req)).await
279 } else {
280 let http_connector = hyper_util::client::legacy::connect::HttpConnector::new();
281 let client: Client<_, Incoming> = Client::builder(TokioExecutor::new())
282 .pool_idle_timeout(connect_timeout)
283 .build(http_connector);
284 tokio::time::timeout(timeout, client.request(req)).await
285 };
286 match result {
287 Ok(Ok(resp)) => Ok(resp),
288 Ok(Err(_)) => Err(Error::NetworkError("bad gateway".into())),
289 Err(_) => Err(Error::Timeout),
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProxySiteConfig;

    /// Builds a minimal `basic` proxy entry for `test.example.com` pointing at `backend`.
    fn entry_for(backend: &str) -> ProxySiteEntry {
        ProxySiteEntry {
            site_id: 1,
            domain: "test.example.com".into(),
            proxy_type: "basic".into(),
            backend_url: url::Url::parse(backend).unwrap(),
            config: ProxySiteConfig::default(),
        }
    }

    /// Maps the request target `path_and_query` through `entry` and renders the backend URI.
    fn rewrite(entry: &ProxySiteEntry, path_and_query: &str) -> String {
        let uri: Uri = path_and_query.parse().unwrap();
        build_backend_uri(entry, &uri).unwrap().to_string()
    }

    #[test]
    fn test_is_hop_by_hop() {
        for hop in ["connection", "keep-alive", "transfer-encoding"] {
            assert!(is_hop_by_hop(&HeaderName::from_static(hop)));
        }
        for end_to_end in ["content-type", "host"] {
            assert!(!is_hop_by_hop(&HeaderName::from_static(end_to_end)));
        }
    }

    #[test]
    fn test_build_backend_uri() {
        let entry = entry_for("http://localhost:3000");
        assert_eq!(
            rewrite(&entry, "/api/test?foo=bar"),
            "http://localhost:3000/api/test?foo=bar"
        );
    }

    #[test]
    fn test_build_backend_uri_root_path() {
        let entry = entry_for("http://localhost:3000");
        assert_eq!(rewrite(&entry, "/"), "http://localhost:3000/");
    }

    #[test]
    fn test_build_backend_uri_with_path_prefix() {
        let entry = entry_for("http://backend:3000/a/");
        let cases = [
            ("/", "http://backend:3000/a/"),
            ("/foo", "http://backend:3000/a/foo"),
            ("/api/test?key=val", "http://backend:3000/a/api/test?key=val"),
        ];
        for (request_target, expected) in cases {
            assert_eq!(rewrite(&entry, request_target), expected);
        }
    }
}
361
362