1use axum::http::{header, HeaderMap, HeaderName, HeaderValue, Uri};
4use hyper::body::Incoming;
5use hyper_util::{client::legacy::Client, rt::TokioExecutor};
6use std::sync::Arc;
7use std::time::Duration;
8
9use crate::prelude::*;
10use crate::ProxySiteEntry;
11
/// Lower-case names of headers that only describe a single transport hop and
/// therefore must not be forwarded by a proxy in either direction.
///
/// Based on the RFC 2616 §13.5.1 list plus the fields RFC 9110 §7.6.1 says
/// intermediaries should remove. The RFC 2616 text says "Trailers", which is an
/// erratum; the actual field is `Trailer`. `proxy-connection` is non-standard but
/// still emitted by some clients.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
22
23fn is_hop_by_hop(name: &HeaderName) -> bool {
25 HOP_BY_HOP_HEADERS.iter().any(|h| name.as_str().eq_ignore_ascii_case(h))
26}
27
28fn is_websocket_upgrade(headers: &HeaderMap) -> bool {
30 headers
31 .get(header::UPGRADE)
32 .and_then(|v| v.to_str().ok())
33 .map(|v| v.eq_ignore_ascii_case("websocket"))
34 .unwrap_or(false)
35}
36
37fn build_backend_uri(entry: &ProxySiteEntry, original_uri: &Uri) -> ClResult<Uri> {
39 let mut backend = entry.backend_url.clone();
40 let combined_path = format!("{}{}", backend.path().trim_end_matches('/'), original_uri.path());
41 backend.set_path(&combined_path);
42 backend.set_query(original_uri.query());
43 debug!("Proxy backend URI: {} (combined_path={:?})", backend.as_str(), combined_path);
44 backend
45 .as_str()
46 .parse::<Uri>()
47 .map_err(|e| Error::Internal(format!("failed to build backend URI: {}", e)))
48}
49
50fn copy_headers(src: &HeaderMap, dst: &mut HeaderMap, is_websocket: bool) {
52 for (name, value) in src.iter() {
53 if is_hop_by_hop(name) {
55 if is_websocket && name == header::UPGRADE {
56 dst.insert(name.clone(), value.clone());
57 }
58 continue;
59 }
60 dst.append(name.clone(), value.clone());
61 }
62}
63
64pub async fn handle_proxy_request(
66 entry: Arc<ProxySiteEntry>,
67 req: hyper::Request<Incoming>,
68 peer_addr: &str,
69) -> Result<hyper::Response<Incoming>, Error> {
70 let is_ws = is_websocket_upgrade(req.headers()) && entry.config.websocket.unwrap_or(true);
71
72 if is_ws {
73 return handle_websocket_proxy(entry, req, peer_addr).await;
74 }
75
76 let backend_uri = build_backend_uri(&entry, req.uri())?;
77
78 let mut backend_headers = HeaderMap::new();
80 copy_headers(req.headers(), &mut backend_headers, false);
81
82 let preserve_host = entry.config.preserve_host.unwrap_or(true);
84 if preserve_host {
85 if let Some(host) = req.headers().get(header::HOST) {
87 backend_headers.insert(header::HOST, host.clone());
88 }
89 } else if let Some(host) = entry.backend_url.host_str() {
90 let host_val = if let Some(port) = entry.backend_url.port() {
92 format!("{}:{}", host, port)
93 } else {
94 host.to_string()
95 };
96 if let Ok(hv) = HeaderValue::from_str(&host_val) {
97 backend_headers.insert(header::HOST, hv);
98 }
99 }
100
101 let forward_headers = if entry.proxy_type.as_ref() == "basic" {
103 true
104 } else {
105 entry.config.forward_headers.unwrap_or(true)
106 };
107 if forward_headers {
108 if let Ok(hv) = HeaderValue::from_str(peer_addr) {
109 backend_headers.insert(HeaderName::from_static("x-forwarded-for"), hv.clone());
110 backend_headers.insert(HeaderName::from_static("x-real-ip"), hv);
111 }
112 backend_headers.insert(
113 HeaderName::from_static("x-forwarded-proto"),
114 HeaderValue::from_static("https"),
115 );
116 if let Ok(hv) = HeaderValue::from_str(&entry.domain) {
117 backend_headers.insert(HeaderName::from_static("x-forwarded-host"), hv);
118 }
119 }
120
121 if let Some(custom_headers) = &entry.config.custom_headers {
123 for (name, value) in custom_headers {
124 if let (Ok(hn), Ok(hv)) =
125 (HeaderName::from_bytes(name.as_bytes()), HeaderValue::from_str(value))
126 {
127 backend_headers.insert(hn, hv);
128 }
129 }
130 }
131
132 let method = req.method().clone();
134 let body = req.into_body();
135
136 let mut backend_req = hyper::Request::builder().method(method).uri(backend_uri);
137
138 if let Some(headers) = backend_req.headers_mut() {
139 *headers = backend_headers;
140 }
141
142 let backend_req = backend_req
143 .body(body)
144 .map_err(|e| Error::Internal(format!("failed to build backend request: {}", e)))?;
145
146 let connect_timeout =
148 Duration::from_secs(entry.config.connect_timeout_secs.unwrap_or(5) as u64);
149 let read_timeout = Duration::from_secs(entry.config.read_timeout_secs.unwrap_or(30) as u64);
150
151 let scheme = entry.backend_url.scheme();
153 match send_backend_request(scheme, connect_timeout, read_timeout, backend_req).await {
154 Ok(mut backend_resp) => {
155 let headers_to_remove: Vec<HeaderName> = backend_resp
157 .headers()
158 .keys()
159 .filter(|name| is_hop_by_hop(name))
160 .cloned()
161 .collect();
162 for name in headers_to_remove {
163 backend_resp.headers_mut().remove(&name);
164 }
165 Ok(backend_resp)
166 }
167 Err(e @ Error::Timeout) => {
168 warn!("Proxy backend timeout for {}", entry.domain);
169 Err(e)
170 }
171 Err(e) => {
172 warn!("Proxy backend error for {}: {}", entry.domain, e);
173 Err(e)
174 }
175 }
176}
177
178async fn handle_websocket_proxy(
180 entry: Arc<ProxySiteEntry>,
181 req: hyper::Request<Incoming>,
182 peer_addr: &str,
183) -> Result<hyper::Response<Incoming>, Error> {
184 let backend_uri = build_backend_uri(&entry, req.uri())?;
187
188 let mut backend_headers = HeaderMap::new();
189 for (name, value) in req.headers().iter() {
191 if is_hop_by_hop(name) && name != header::UPGRADE {
192 continue;
193 }
194 backend_headers.append(name.clone(), value.clone());
195 }
196
197 let preserve_host = entry.config.preserve_host.unwrap_or(true);
199 if !preserve_host {
200 if let Some(host) = entry.backend_url.host_str() {
201 let host_val = if let Some(port) = entry.backend_url.port() {
202 format!("{}:{}", host, port)
203 } else {
204 host.to_string()
205 };
206 if let Ok(hv) = HeaderValue::from_str(&host_val) {
207 backend_headers.insert(header::HOST, hv);
208 }
209 }
210 }
211
212 let forward_headers = if entry.proxy_type.as_ref() == "basic" {
214 true
215 } else {
216 entry.config.forward_headers.unwrap_or(true)
217 };
218 if forward_headers {
219 if let Ok(hv) = HeaderValue::from_str(peer_addr) {
220 backend_headers.insert(HeaderName::from_static("x-forwarded-for"), hv);
221 }
222 backend_headers.insert(
223 HeaderName::from_static("x-forwarded-proto"),
224 HeaderValue::from_static("https"),
225 );
226 }
227
228 backend_headers.insert(header::CONNECTION, HeaderValue::from_static("Upgrade"));
230
231 let method = req.method().clone();
232 let body = req.into_body();
233
234 let mut backend_req = hyper::Request::builder().method(method).uri(backend_uri);
235
236 if let Some(headers) = backend_req.headers_mut() {
237 *headers = backend_headers;
238 }
239
240 let backend_req = backend_req
241 .body(body)
242 .map_err(|e| Error::Internal(format!("failed to build ws backend request: {}", e)))?;
243
244 let connect_timeout =
246 Duration::from_secs(entry.config.connect_timeout_secs.unwrap_or(5) as u64);
247
248 let scheme = entry.backend_url.scheme();
249 match send_backend_request(scheme, connect_timeout, connect_timeout, backend_req).await {
250 Ok(backend_resp) => Ok(backend_resp),
251 Err(e @ Error::Timeout) => {
252 warn!("WebSocket proxy backend timeout for {}", entry.domain);
253 Err(e)
254 }
255 Err(e) => {
256 warn!("WebSocket proxy backend error for {}: {}", entry.domain, e);
257 Err(e)
258 }
259 }
260}
261
262async fn send_backend_request(
264 scheme: &str,
265 connect_timeout: Duration,
266 timeout: Duration,
267 req: hyper::Request<Incoming>,
268) -> Result<hyper::Response<Incoming>, Error> {
269 let result = if scheme == "https" {
270 let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
271 .with_native_roots()
272 .map_err(|_| Error::ConfigError("no native root CA certificates found".into()))?
273 .https_only()
274 .enable_http1()
275 .build();
276 let client: Client<_, Incoming> = Client::builder(TokioExecutor::new())
277 .pool_idle_timeout(connect_timeout)
278 .build(https_connector);
279 tokio::time::timeout(timeout, client.request(req)).await
280 } else {
281 let http_connector = hyper_util::client::legacy::connect::HttpConnector::new();
282 let client: Client<_, Incoming> = Client::builder(TokioExecutor::new())
283 .pool_idle_timeout(connect_timeout)
284 .build(http_connector);
285 tokio::time::timeout(timeout, client.request(req)).await
286 };
287 match result {
288 Ok(Ok(resp)) => Ok(resp),
289 Ok(Err(_)) => Err(Error::NetworkError("bad gateway".into())),
290 Err(_) => Err(Error::Timeout),
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal "basic" proxy entry pointing at `backend`.
    fn entry_for(backend: &str) -> ProxySiteEntry {
        ProxySiteEntry {
            site_id: 1,
            domain: "test.example.com".into(),
            proxy_type: "basic".into(),
            backend_url: url::Url::parse(backend).unwrap(),
            config: Default::default(),
        }
    }

    /// Resolves `request_uri` against `entry` and renders the backend URI.
    fn backend_for(entry: &ProxySiteEntry, request_uri: &str) -> String {
        let uri = request_uri.parse::<Uri>().unwrap();
        build_backend_uri(entry, &uri).unwrap().to_string()
    }

    #[test]
    fn test_is_hop_by_hop() {
        assert!(is_hop_by_hop(&HeaderName::from_static("connection")));
        assert!(is_hop_by_hop(&HeaderName::from_static("keep-alive")));
        assert!(is_hop_by_hop(&HeaderName::from_static("transfer-encoding")));
        assert!(!is_hop_by_hop(&HeaderName::from_static("content-type")));
        assert!(!is_hop_by_hop(&HeaderName::from_static("host")));
    }

    #[test]
    fn test_is_websocket_upgrade() {
        let mut headers = HeaderMap::new();
        assert!(!is_websocket_upgrade(&headers));

        headers.insert(header::UPGRADE, HeaderValue::from_static("WebSocket"));
        assert!(is_websocket_upgrade(&headers));

        headers.insert(header::UPGRADE, HeaderValue::from_static("h2c"));
        assert!(!is_websocket_upgrade(&headers));
    }

    #[test]
    fn test_copy_headers_drops_hop_by_hop() {
        let mut src = HeaderMap::new();
        src.insert(header::CONNECTION, HeaderValue::from_static("keep-alive"));
        src.insert(HeaderName::from_static("keep-alive"), HeaderValue::from_static("timeout=5"));
        src.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/plain"));
        src.insert(HeaderName::from_static("x-request-id"), HeaderValue::from_static("abc"));

        let mut dst = HeaderMap::new();
        copy_headers(&src, &mut dst, false);

        assert_eq!(dst.len(), 2);
        assert_eq!(
            dst.get(header::CONTENT_TYPE).and_then(|v| v.to_str().ok()),
            Some("text/plain")
        );
        assert_eq!(dst.get("x-request-id").and_then(|v| v.to_str().ok()), Some("abc"));
        assert!(dst.get(header::CONNECTION).is_none());
    }

    #[test]
    fn test_build_backend_uri() {
        let entry = entry_for("http://localhost:3000");
        assert_eq!(
            backend_for(&entry, "/api/test?foo=bar"),
            "http://localhost:3000/api/test?foo=bar"
        );
    }

    #[test]
    fn test_build_backend_uri_root_path() {
        let entry = entry_for("http://localhost:3000");
        assert_eq!(backend_for(&entry, "/"), "http://localhost:3000/");
    }

    #[test]
    fn test_build_backend_uri_with_path_prefix() {
        let entry = entry_for("http://backend:3000/a/");
        assert_eq!(backend_for(&entry, "/"), "http://backend:3000/a/");
        assert_eq!(backend_for(&entry, "/foo"), "http://backend:3000/a/foo");
        assert_eq!(
            backend_for(&entry, "/api/test?key=val"),
            "http://backend:3000/a/api/test?key=val"
        );
    }

    #[test]
    fn test_build_backend_uri_prefix_without_trailing_slash() {
        let entry = entry_for("http://backend:3000/a");
        assert_eq!(backend_for(&entry, "/"), "http://backend:3000/a/");
        assert_eq!(backend_for(&entry, "/foo"), "http://backend:3000/a/foo");
    }
}
361
362