1pub mod cgi;
2pub mod circuit_breaker;
3pub mod dns_upstream;
4pub mod fastcgi;
5pub mod forward_proxy;
6pub mod health;
7pub mod lb;
8pub mod scgi;
9pub mod srv_upstream;
10#[cfg(unix)]
11pub mod unix_upstream;
12pub mod upstream;
13pub mod websocket;
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use http::{Request, Response, Uri};
19use http_body_util::BodyExt;
20use regex::Regex;
21use tracing::{debug, warn};
22
23use self::health::{HealthChecker, PassiveHealthChecker};
24use self::lb::*;
25use self::upstream::UpstreamPool;
26use crate::config::{LbPolicy, ProxyConfig};
27use crate::{Body, ProxyError, goals};
28
29pub struct ReverseProxy {
32 pool: Arc<UpstreamPool>,
33 lb: Box<dyn LoadBalancer>,
34 headers_up: Vec<(String, String)>,
35 headers_down: Vec<(String, String)>,
36 retries: u32,
37 _health_checker: Option<HealthChecker>,
39 passive_health: Option<Arc<PassiveHealthChecker>>,
41 error_pages: HashMap<u16, String>,
43 headers_up_replace: Vec<(String, Regex, String)>,
46 sanitize_uri: bool,
49}
50
51impl ReverseProxy {
52 pub fn new(config: &ProxyConfig) -> Self {
53 let pool = Arc::new(UpstreamPool::from_config(config));
54
55 let weights: Vec<u32> = config.upstreams.iter().map(|u| u.weight).collect();
57
58 let lb: Box<dyn LoadBalancer> = match config.lb {
59 LbPolicy::RoundRobin => Box::new(RoundRobinLb::new()),
60 LbPolicy::Random => Box::new(RandomLb::new()),
61 LbPolicy::WeightedRoundRobin => Box::new(WeightedRoundRobinLb::new(&weights)),
62 LbPolicy::IpHash => Box::new(IpHashLb::new()),
63 LbPolicy::LeastConn => Box::new(LeastConnLb::new()),
64 LbPolicy::UriHash => Box::new(UriHashLb::new()),
65 LbPolicy::HeaderHash => {
66 let name = config
67 .lb_header
68 .clone()
69 .unwrap_or_else(|| "X-Forwarded-For".to_string());
70 Box::new(HeaderHashLb::new(name))
71 }
72 LbPolicy::CookieHash => {
73 let name = config
74 .lb_cookie
75 .clone()
76 .unwrap_or_else(|| "session".to_string());
77 Box::new(CookieHashLb::new(name))
78 }
79 LbPolicy::First => Box::new(FirstLb::new()),
80 LbPolicy::TwoRandomChoices => Box::new(TwoRandomChoicesLb::new()),
81 };
82
83 let headers_up: Vec<(String, String)> = config
84 .headers_up
85 .iter()
86 .map(|(k, v)| (k.clone(), v.clone()))
87 .collect();
88 let headers_down: Vec<(String, String)> = config
89 .headers_down
90 .iter()
91 .map(|(k, v)| (k.clone(), v.clone()))
92 .collect();
93
94 let health_checker = config
96 .health_check
97 .as_ref()
98 .map(|hc| HealthChecker::start(Arc::clone(&pool), hc));
99
100 let passive_health = config
102 .passive_health
103 .as_ref()
104 .map(|ph| Arc::new(PassiveHealthChecker::new(pool.len(), ph)));
105
106 let headers_up_replace: Vec<(String, Regex, String)> = config
108 .headers_up_replace
109 .iter()
110 .filter_map(|(name, pattern, replacement)| match Regex::new(pattern) {
111 Ok(re) => Some((name.clone(), re, replacement.clone())),
112 Err(e) => {
113 warn!(
114 header = name.as_str(),
115 pattern = pattern.as_str(),
116 error = %e,
117 "invalid regex in header-up-replace, skipping"
118 );
119 None
120 }
121 })
122 .collect();
123
124 Self {
125 pool,
126 lb,
127 headers_up,
128 headers_down,
129 retries: config.retries,
130 _health_checker: health_checker,
131 passive_health,
132 error_pages: config.error_pages.clone(),
133 headers_up_replace,
134 sanitize_uri: config.sanitize_uri,
135 }
136 }
137}
138
139#[salvo::async_trait]
140impl salvo::Handler for ReverseProxy {
141 async fn handle(
142 &self,
143 req: &mut salvo::Request,
144 _depot: &mut salvo::Depot,
145 res: &mut salvo::Response,
146 ctrl: &mut salvo::FlowCtrl,
147 ) {
148 let client_addr = crate::hoops::client_addr(req);
149 let request = match goals::strip_request(req) {
150 Ok(r) => r,
151 Err(e) => {
152 goals::merge_response(res, e.into_response());
153 ctrl.skip_rest();
154 return;
155 }
156 };
157 let response = self
158 .proxy(request, client_addr)
159 .await
160 .unwrap_or_else(|e| e.into_response());
161 goals::merge_response(res, response);
162 ctrl.skip_rest();
163 }
164}
165
166impl ReverseProxy {
167 async fn proxy(
168 &self,
169 request: Request<Body>,
170 client_addr: std::net::SocketAddr,
171 ) -> Result<Response<Body>, ProxyError> {
172 if websocket::is_websocket_upgrade(&request) {
176 debug!(client = %client_addr, "detected WebSocket upgrade request");
177
178 let ws_lb_ctx = LbContext {
180 client_addr,
181 uri: request
182 .uri()
183 .path_and_query()
184 .map(|pq| pq.as_str().to_string())
185 .unwrap_or_else(|| "/".to_string()),
186 headers: request.headers().clone(),
187 };
188
189 let backend_idx = self
190 .lb
191 .select(&self.pool, &ws_lb_ctx)
192 .ok_or(ProxyError::NoUpstream)?;
193 let backend = &self.pool.backends[backend_idx];
194 let _conn_guard = self.pool.acquire_conn(backend_idx);
195
196 return websocket::proxy_websocket(request, &backend.addr).await;
197 }
198
199 let lb_ctx = LbContext {
201 client_addr,
202 uri: request
203 .uri()
204 .path_and_query()
205 .map(|pq| pq.as_str().to_string())
206 .unwrap_or_else(|| "/".to_string()),
207 headers: request.headers().clone(),
208 };
209
210 if let Some(max_conns) = self.pool.max_connections {
212 let total = self.pool.total_active_conns();
213 if total >= max_conns {
214 warn!(
215 limit = max_conns,
216 active = total,
217 "connection limit exceeded, returning 503"
218 );
219 return Response::builder()
220 .status(http::StatusCode::SERVICE_UNAVAILABLE)
221 .body(crate::full_body(
222 "Service Unavailable: connection limit exceeded",
223 ))
224 .map_err(|e| ProxyError::Internal(e.to_string()));
225 }
226 }
227
228 let (parts, body) = request.into_parts();
230 let body_bytes = body
231 .collect()
232 .await
233 .map_err(|e| ProxyError::Internal(format!("failed to buffer body: {e}")))?
234 .to_bytes();
235
236 let max_attempts = 1 + self.retries;
237 let mut last_failed_idx: Option<usize> = None;
238 let mut last_error: Option<ProxyError> = None;
239
240 for attempt in 0..max_attempts {
241 let backend_idx = {
243 let idx = self.lb.select(&self.pool, &lb_ctx);
244 match idx {
245 Some(i) if last_failed_idx == Some(i) && self.pool.len() > 1 => {
246 self.lb.select(&self.pool, &lb_ctx)
248 }
249 other => other,
250 }
251 };
252
253 let backend_idx = match backend_idx {
254 Some(i) => i,
255 None => {
256 return Err(last_error.unwrap_or(ProxyError::NoUpstream));
257 }
258 };
259
260 let backend = &self.pool.backends[backend_idx];
261
262 let mut req_parts = parts.clone();
264
265 if self.sanitize_uri {
267 let raw_pq = req_parts
268 .uri
269 .path_and_query()
270 .map(|pq| pq.as_str().to_string())
271 .unwrap_or_else(|| "/".to_string());
272 let (raw_path, raw_query) = if let Some(pos) = raw_pq.find('?') {
273 (&raw_pq[..pos], Some(&raw_pq[pos + 1..]))
274 } else {
275 (raw_pq.as_str(), None)
276 };
277 let sanitized_path = sanitize_path(raw_path);
278 let sanitized_pq = match raw_query {
279 Some(q) if !q.is_empty() => format!("{sanitized_path}?{q}"),
280 _ => sanitized_path,
281 };
282 if let Ok(new_uri) = sanitized_pq.parse::<http::uri::PathAndQuery>() {
283 let mut builder = Uri::builder();
285 if let Some(scheme) = req_parts.uri.scheme() {
286 builder = builder.scheme(scheme.clone());
287 }
288 if let Some(authority) = req_parts.uri.authority() {
289 builder = builder.authority(authority.clone());
290 }
291 builder = builder.path_and_query(new_uri);
292 if let Ok(u) = builder.build() {
293 req_parts.uri = u;
294 }
295 }
296 }
297
298 let scheme =
301 if backend.addr.starts_with("https://") || backend.addr.starts_with("http://") {
302 ""
303 } else {
304 "http://"
305 };
306 let upstream_uri = format!(
307 "{}{}{}",
308 scheme,
309 backend.addr,
310 req_parts
311 .uri
312 .path_and_query()
313 .map(|pq| pq.as_str())
314 .unwrap_or("/")
315 );
316 req_parts.uri = match upstream_uri.parse::<Uri>() {
317 Ok(u) => u,
318 Err(e) => {
319 return Err(ProxyError::Internal(format!("invalid upstream URI: {e}")));
320 }
321 };
322
323 if let Ok(hv) = backend.addr.parse() {
325 req_parts.headers.insert(http::header::HOST, hv);
326 }
327
328 for (name, value) in &self.headers_up {
330 if let Some(hdr_name) = name.strip_prefix('-') {
331 if let Ok(hn) = hdr_name.parse::<http::header::HeaderName>() {
332 req_parts.headers.remove(hn);
333 }
334 } else {
335 let expanded = value.replace("{client_ip}", &client_addr.ip().to_string());
336 if let (Ok(hn), Ok(hv)) = (
337 name.parse::<http::header::HeaderName>(),
338 expanded.parse::<http::header::HeaderValue>(),
339 ) {
340 req_parts.headers.insert(hn, hv);
341 }
342 }
343 }
344
345 for (name, re, replacement) in &self.headers_up_replace {
347 if let Ok(hn) = name.parse::<http::header::HeaderName>()
348 && let Some(existing) = req_parts.headers.get(&hn)
349 && let Ok(existing_str) = existing.to_str()
350 {
351 let new_value = re.replace_all(existing_str, replacement.as_str());
352 if let Ok(hv) = new_value.as_ref().parse::<http::header::HeaderValue>() {
353 req_parts.headers.insert(hn, hv);
354 }
355 }
356 }
357
358 let req_body = crate::full_body(body_bytes.clone());
360 let upstream_req = Request::from_parts(req_parts, req_body);
361
362 debug!(
363 upstream = %backend.addr,
364 attempt = attempt + 1,
365 "forwarding request"
366 );
367
368 let _conn_guard = self.pool.acquire_conn(backend_idx);
370
371 let result = if is_unix_socket(&backend.addr) {
374 #[cfg(unix)]
375 {
376 let path = unix_socket_path(&backend.addr);
377 send_via_unix(path, upstream_req).await.map(|r| {
378 r.map(|b| {
379 let b: Body = b.map_err(|e| -> crate::BoxError { Box::new(e) }).boxed();
380 b
381 })
382 })
383 }
384 #[cfg(not(unix))]
385 {
386 let _ = upstream_req;
387 Err(ProxyError::Internal(
388 "Unix domain socket upstreams are not supported on this platform".into(),
389 ))
390 }
391 } else {
392 self.pool
393 .client
394 .request(upstream_req)
395 .await
396 .map_err(ProxyError::Client)
397 .map(|r| {
398 r.map(|b| {
399 let b: Body = b.map_err(|e| -> crate::BoxError { Box::new(e) }).boxed();
400 b
401 })
402 })
403 };
404
405 match result {
406 Ok(resp) => {
407 let (mut resp_parts, resp_body) = resp.into_parts();
408
409 if resp_parts.status.is_server_error()
411 && let Some(ref ph) = self.passive_health
412 {
413 ph.record_failure(backend_idx, &self.pool).await;
414 }
415
416 if resp_parts.status.is_server_error() && attempt + 1 < max_attempts {
418 warn!(
419 upstream = %backend.addr,
420 status = %resp_parts.status,
421 attempt = attempt + 1,
422 "upstream returned server error, retrying"
423 );
424 last_failed_idx = Some(backend_idx);
425 last_error = Some(ProxyError::Internal(format!(
426 "upstream {} returned {}",
427 backend.addr, resp_parts.status
428 )));
429 continue;
430 }
431
432 for (name, value) in &self.headers_down {
436 if let Some(hdr_name) = name.strip_prefix('-') {
437 if let Ok(hn) = hdr_name.parse::<http::header::HeaderName>() {
438 resp_parts.headers.remove(hn);
439 }
440 } else if let (Ok(hn), Ok(hv)) = (
441 name.parse::<http::header::HeaderName>(),
442 value.parse::<http::header::HeaderValue>(),
443 ) {
444 resp_parts.headers.insert(hn, hv);
445 }
446 }
447
448 if let Some(ref ph) = self.passive_health {
450 ph.maybe_recover(&self.pool).await;
451 }
452
453 let status_code = resp_parts.status.as_u16();
456 if let Some(error_body) = self.error_pages.get(&status_code) {
457 debug!(
458 status = status_code,
459 "intercepting upstream error with configured error page"
460 );
461 return Ok(Response::from_parts(
462 resp_parts,
463 crate::full_body(error_body.clone()),
464 ));
465 }
466
467 return Ok(Response::from_parts(resp_parts, resp_body));
468 }
469 Err(e) => {
470 if let Some(ref ph) = self.passive_health {
472 ph.record_failure(backend_idx, &self.pool).await;
473 }
474
475 if attempt + 1 < max_attempts {
476 warn!(
477 upstream = %backend.addr,
478 error = %e,
479 attempt = attempt + 1,
480 "upstream request failed, retrying"
481 );
482 last_failed_idx = Some(backend_idx);
483 last_error = Some(e);
484 continue;
485 }
486
487 return Err(e);
488 }
489 }
490 }
491
492 Err(last_error.unwrap_or(ProxyError::NoUpstream))
494 }
495}
496
/// Normalises a URI path before it is forwarded upstream.
///
/// - Empty segments (repeated slashes) and `.` segments are dropped.
/// - `..` removes the previous segment and can never climb above the root.
/// - The result always starts with `/`.
/// - A trailing slash is preserved: a path that ends in `/`, `.` or `..`
///   refers to a directory and keeps its final `/`, matching the
///   dot-segment removal of RFC 3986. Dropping it would change which
///   resource the upstream sees (`/docs/` versus `/docs`).
///
/// The query string is not handled here; pass only the path component.
fn sanitize_path(path: &str) -> String {
    let mut segments: Vec<&str> = Vec::new();
    // True while the most recently seen segment denotes a directory
    // (empty, `.` or `..`); decides whether the output ends with `/`.
    let mut ends_with_dir = false;

    for segment in path.split('/') {
        match segment {
            "" | "." => ends_with_dir = true,
            ".." => {
                // Popping an empty stack is a no-op, which clamps at the root.
                segments.pop();
                ends_with_dir = true;
            }
            name => {
                segments.push(name);
                ends_with_dir = false;
            }
        }
    }

    // +1 for the leading slash when the input did not have one.
    let mut result = String::with_capacity(path.len() + 1);
    for name in &segments {
        result.push('/');
        result.push_str(name);
    }
    if segments.is_empty() || ends_with_dir {
        result.push('/');
    }
    result
}
527
/// Reports whether an upstream address designates a Unix domain socket:
/// either an explicit `unix:` prefix or a path starting with `/`.
fn is_unix_socket(addr: &str) -> bool {
    let is_absolute_path = matches!(addr.as_bytes().first(), Some(b'/'));
    is_absolute_path || addr.strip_prefix("unix:").is_some()
}
537
/// Extracts the filesystem path from a Unix-socket upstream address by
/// dropping a leading `unix:` prefix; other addresses are returned unchanged.
#[cfg(unix)]
fn unix_socket_path(addr: &str) -> &str {
    match addr.strip_prefix("unix:") {
        Some(path) => path,
        None => addr,
    }
}
543
544#[cfg(unix)]
549async fn send_via_unix(
550 socket_path: &str,
551 request: http::Request<Body>,
552) -> Result<http::Response<hyper::body::Incoming>, ProxyError> {
553 let stream = tokio::net::UnixStream::connect(socket_path)
554 .await
555 .map_err(|e| ProxyError::Internal(format!("unix socket connect to {socket_path}: {e}")))?;
556 let io = hyper_util::rt::TokioIo::new(stream);
557 let (mut sender, conn) = hyper::client::conn::http1::handshake(io)
558 .await
559 .map_err(ProxyError::Hyper)?;
560 tokio::spawn(async move {
561 let _ = conn.await;
562 });
563 sender
564 .send_request(request)
565 .await
566 .map_err(ProxyError::Hyper)
567}