1#![allow(clippy::let_unit_value, clippy::clone_on_copy, clippy::unit_arg)]
3
4use crate::acme::ChallengeStore;
5use crate::app::AppManager;
6use crate::circuit_breaker::SharedCircuitBreaker;
7use crate::config::ConfigManager;
8use crate::metrics::SharedMetrics;
9use crate::shutdown::ShutdownCoordinator;
10use anyhow::Result;
11use bytes::Bytes;
12use http_body_util::BodyExt;
13use hyper::body::Incoming;
14use hyper::header::HeaderValue;
15use hyper::service::service_fn;
16use hyper::Request;
17use hyper::Response;
18use hyper_util::client::legacy::connect::HttpConnector;
19use hyper_util::client::legacy::Client;
20use hyper_util::rt::TokioExecutor;
21use hyper_util::rt::TokioIo;
22use socket2::{Domain, Protocol, Socket, Type};
23use std::net::SocketAddr;
24use std::sync::Arc;
25use std::time::Duration;
26use tokio::io::AsyncWriteExt;
27use tokio::net::{TcpListener, TcpStream};
28use tokio_rustls::TlsAcceptor;
29
30#[cfg(feature = "scripting")]
31use crate::scripting::{LuaEngine, LuaRequest, RequestHookResult, RouteHookResult};
32
/// Pooled upstream client used to forward requests to backends; the request
/// body handed to it is the client's hyper `Incoming` body, streamed through.
type ClientType = Client<HttpConnector, Incoming>;
/// Response body type produced by every handler in this module.
type BoxBody = http_body_util::combinators::BoxBody<Bytes, std::convert::Infallible>;

/// Lua engine handle; `None` means no Lua hooks run (every use is guarded by
/// `if let Some(ref engine)`).
#[cfg(feature = "scripting")]
type OptionalLuaEngine = Option<LuaEngine>;
/// Unit placeholder so function signatures are identical when the `scripting`
/// feature is compiled out.
#[cfg(not(feature = "scripting"))]
type OptionalLuaEngine = ();
40
41fn record_app_metrics(
43 metrics: &SharedMetrics,
44 app_manager: &Option<Arc<AppManager>>,
45 target_url: &str,
46 bytes_in: u64,
47 bytes_out: u64,
48 status: u16,
49 duration: Duration,
50) {
51 if let Some(ref manager) = app_manager {
52 if let Ok(url) = url::Url::parse(target_url) {
53 if let Some(port) = url.port() {
54 if let Some(app_name) = futures::executor::block_on(manager.get_app_name(port)) {
55 metrics.record_app_request(&app_name, bytes_in, bytes_out, status, duration);
56 }
57 }
58 }
59 }
60}
61
/// Value inserted as `X-Forwarded-For` on every proxied request.
///
/// NOTE(review): this is always `127.0.0.1`, not the real client address. The
/// accept loops discard the peer address (`Ok((stream, _))`), so backends
/// cannot see who the client was. Confirm this is intended.
static X_FORWARDED_FOR_VALUE: std::sync::LazyLock<HeaderValue> =
    std::sync::LazyLock::new(|| HeaderValue::from_static("127.0.0.1"));
65
66fn create_listener(addr: SocketAddr) -> Result<TcpListener> {
67 let domain = if addr.is_ipv4() {
68 Domain::IPV4
69 } else {
70 Domain::IPV6
71 };
72 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
73 socket.set_reuse_address(true)?;
74 socket.set_reuse_port(true)?;
75 socket.set_nonblocking(true)?;
76 socket.bind(&addr.into())?;
77 socket.listen(8192)?;
78 let std_listener: std::net::TcpListener = socket.into();
79 Ok(TcpListener::from_std(std_listener)?)
80}
81
82fn create_client() -> ClientType {
83 let exec = TokioExecutor::new();
84 let mut connector = HttpConnector::new();
85 connector.set_nodelay(true);
86 connector.set_keepalive(Some(Duration::from_secs(30)));
87 connector.set_connect_timeout(Some(Duration::from_secs(5)));
88 Client::builder(exec)
89 .pool_max_idle_per_host(256)
90 .pool_idle_timeout(Duration::from_secs(60))
91 .build(connector)
92}
93
/// Reverse-proxy front end: holds the shared handles and spawns the HTTP (and,
/// when configured, HTTPS) accept loops from [`ProxyServer::run`].
pub struct ProxyServer {
    /// Configuration manager; `run` reads `server.bind` from it.
    config: Arc<ConfigManager>,
    /// Polled by `run` and the accept loops to decide when to stop.
    shutdown: ShutdownCoordinator,
    /// TLS acceptor for the HTTPS listener; `Some` only when built with `with_https`.
    tls_acceptor: Option<TlsAcceptor>,
    /// HTTPS listen address; set together with `tls_acceptor` by `with_https`.
    https_addr: Option<SocketAddr>,
    /// Request / connection metrics shared with every handler.
    metrics: SharedMetrics,
    /// Pending ACME challenge responses, looked up by token.
    challenge_store: ChallengeStore,
    /// Lua hook engine (`()` when the `scripting` feature is disabled).
    lua_engine: OptionalLuaEngine,
    /// Circuit breaker consulted for target selection and fed backend outcomes.
    circuit_breaker: SharedCircuitBreaker,
    /// Optional registry used to attribute metrics to an app by backend port.
    app_manager: Option<Arc<AppManager>>,
}
105
106impl ProxyServer {
107 pub fn new(
108 config: Arc<ConfigManager>,
109 shutdown: ShutdownCoordinator,
110 metrics: SharedMetrics,
111 challenge_store: ChallengeStore,
112 lua_engine: OptionalLuaEngine,
113 circuit_breaker: SharedCircuitBreaker,
114 app_manager: Option<Arc<AppManager>>,
115 ) -> Result<Self> {
116 Ok(Self {
117 config,
118 shutdown,
119 tls_acceptor: None,
120 https_addr: None,
121 metrics,
122 challenge_store,
123 lua_engine,
124 circuit_breaker,
125 app_manager,
126 })
127 }
128
129 #[allow(clippy::too_many_arguments)]
130 pub fn with_https(
131 config: Arc<ConfigManager>,
132 shutdown: ShutdownCoordinator,
133 tls_acceptor: TlsAcceptor,
134 https_addr: SocketAddr,
135 metrics: SharedMetrics,
136 challenge_store: ChallengeStore,
137 lua_engine: OptionalLuaEngine,
138 circuit_breaker: SharedCircuitBreaker,
139 app_manager: Option<Arc<AppManager>>,
140 ) -> Result<Self> {
141 Ok(Self {
142 config,
143 shutdown,
144 tls_acceptor: Some(tls_acceptor),
145 https_addr: Some(https_addr),
146 metrics,
147 challenge_store,
148 lua_engine,
149 circuit_breaker,
150 app_manager,
151 })
152 }
153
154 pub async fn run(&self) -> Result<()> {
155 let cfg = self.config.get_config();
156 let http_addr: SocketAddr = cfg.server.bind.parse()?;
157 let https_addr = self.https_addr;
158
159 let has_https = https_addr.is_some();
160 let num_listeners = std::thread::available_parallelism()
161 .map(|n| n.get())
162 .unwrap_or(4);
163
164 let app_manager = self.app_manager.clone();
167 for i in 0..num_listeners {
168 let config_clone = self.config.clone();
169 let shutdown_clone = self.shutdown.clone();
170 let metrics_clone = self.metrics.clone();
171 let challenge_store_clone = self.challenge_store.clone();
172 let lua_clone = self.lua_engine.clone();
173 let cb_clone = self.circuit_breaker.clone();
174 let am_clone = app_manager.clone();
175
176 tokio::spawn(async move {
177 if let Err(e) = run_http_server(
178 http_addr,
179 config_clone,
180 shutdown_clone,
181 metrics_clone,
182 challenge_store_clone,
183 lua_clone,
184 cb_clone,
185 am_clone,
186 )
187 .await
188 {
189 tracing::error!("HTTP/1.1 server error (listener {}): {}", i, e);
190 }
191 });
192 }
193
194 if let Some(https_addr) = https_addr {
195 for i in 0..num_listeners {
196 let config_clone = self.config.clone();
197 let shutdown_clone = self.shutdown.clone();
198 let acceptor = self.tls_acceptor.as_ref().unwrap().clone();
199 let metrics_clone = self.metrics.clone();
200 let challenge_store_clone = self.challenge_store.clone();
201 let lua_clone = self.lua_engine.clone();
202 let cb_clone = self.circuit_breaker.clone();
203 let am_clone = app_manager.clone();
204
205 tokio::spawn(async move {
206 if let Err(e) = run_https_server(
207 https_addr,
208 config_clone,
209 shutdown_clone,
210 acceptor,
211 metrics_clone,
212 challenge_store_clone,
213 lua_clone,
214 cb_clone,
215 am_clone,
216 )
217 .await
218 {
219 tracing::error!("HTTPS/2 server error (listener {}): {}", i, e);
220 }
221 });
222 }
223 }
224
225 tracing::info!(
226 "HTTP/1.1 server listening on {} ({} accept loops)",
227 http_addr,
228 num_listeners
229 );
230 if has_https {
231 tracing::info!(
232 "HTTPS/2 server listening on {} ({} accept loops)",
233 https_addr.unwrap(),
234 num_listeners
235 );
236 }
237
238 loop {
239 if self.shutdown.is_shutting_down() {
240 tracing::info!("Shutting down servers...");
241 break;
242 }
243 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
244 }
245
246 Ok(())
247 }
248}
249
250#[allow(clippy::too_many_arguments)]
251async fn run_http_server(
252 addr: SocketAddr,
253 config: Arc<ConfigManager>,
254 shutdown: ShutdownCoordinator,
255 metrics: SharedMetrics,
256 challenge_store: ChallengeStore,
257 lua_engine: OptionalLuaEngine,
258 circuit_breaker: SharedCircuitBreaker,
259 app_manager: Option<Arc<AppManager>>,
260) -> Result<()> {
261 let listener = create_listener(addr)?;
262 let client = create_client();
263
264 loop {
265 if shutdown.is_shutting_down() {
266 break;
267 }
268
269 match listener.accept().await {
270 Ok((stream, _)) => {
271 let _ = stream.set_nodelay(true);
272 let client = client.clone();
273 let config = config.clone();
274 let metrics = metrics.clone();
275 let cs = challenge_store.clone();
276 let lua = lua_engine.clone();
277 let cb = circuit_breaker.clone();
278 let am = app_manager.clone();
279 tokio::spawn(async move {
280 if let Err(e) =
281 handle_http11_connection(stream, client, config, metrics, cs, lua, cb, am)
282 .await
283 {
284 tracing::debug!("HTTP/1.1 connection error: {}", e);
285 }
286 });
287 }
288 Err(e) => {
289 tracing::error!("HTTP/1.1 accept error: {}", e);
290 }
291 }
292 }
293
294 Ok(())
295}
296
297#[allow(clippy::too_many_arguments)]
298async fn run_https_server(
299 addr: SocketAddr,
300 config: Arc<ConfigManager>,
301 shutdown: ShutdownCoordinator,
302 acceptor: TlsAcceptor,
303 metrics: SharedMetrics,
304 challenge_store: ChallengeStore,
305 lua_engine: OptionalLuaEngine,
306 circuit_breaker: SharedCircuitBreaker,
307 app_manager: Option<Arc<AppManager>>,
308) -> Result<()> {
309 let listener = create_listener(addr)?;
310 let client = create_client();
311
312 loop {
313 if shutdown.is_shutting_down() {
314 break;
315 }
316
317 match listener.accept().await {
318 Ok((stream, _)) => {
319 let _ = stream.set_nodelay(true);
320 let client = client.clone();
321 let config = config.clone();
322 let acceptor = acceptor.clone();
323 let metrics = metrics.clone();
324 let cs = challenge_store.clone();
325 let lua = lua_engine.clone();
326 let cb = circuit_breaker.clone();
327 let am = app_manager.clone();
328 tokio::spawn(async move {
329 match acceptor.accept(stream).await {
330 Ok(tls_stream) => {
331 metrics.inc_tls_connections();
332 if let Err(e) = handle_https2_connection(
333 tls_stream, client, config, metrics, cs, lua, cb, am,
334 )
335 .await
336 {
337 tracing::debug!("HTTPS/2 connection error: {}", e);
338 }
339 }
340 Err(e) => {
341 tracing::error!("TLS accept error: {}", e);
342 }
343 }
344 });
345 }
346 Err(e) => {
347 tracing::error!("HTTPS/2 accept error: {}", e);
348 }
349 }
350 }
351
352 Ok(())
353}
354
355#[allow(clippy::too_many_arguments)]
356async fn handle_http11_connection(
357 stream: tokio::net::TcpStream,
358 client: ClientType,
359 config: Arc<ConfigManager>,
360 metrics: SharedMetrics,
361 challenge_store: ChallengeStore,
362 lua_engine: OptionalLuaEngine,
363 circuit_breaker: SharedCircuitBreaker,
364 app_manager: Option<Arc<AppManager>>,
365) -> Result<()> {
366 let io = TokioIo::new(stream);
367 let svc = service_fn(move |req| {
368 handle_request(
369 req,
370 client.clone(),
371 config.clone(),
372 metrics.clone(),
373 challenge_store.clone(),
374 lua_engine.clone(),
375 circuit_breaker.clone(),
376 app_manager.clone(),
377 )
378 });
379
380 let conn = hyper::server::conn::http1::Builder::new()
381 .keep_alive(true)
382 .pipeline_flush(true)
383 .serve_connection(io, svc)
384 .with_upgrades();
385
386 if let Err(e) = conn.await {
387 tracing::debug!("HTTP/1.1 connection error: {}", e);
388 }
389
390 Ok(())
391}
392
393#[allow(clippy::too_many_arguments)]
394async fn handle_https2_connection(
395 stream: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
396 client: ClientType,
397 config: Arc<ConfigManager>,
398 metrics: SharedMetrics,
399 challenge_store: ChallengeStore,
400 lua_engine: OptionalLuaEngine,
401 circuit_breaker: SharedCircuitBreaker,
402 app_manager: Option<Arc<AppManager>>,
403) -> Result<()> {
404 let is_h2 = stream.get_ref().1.alpn_protocol() == Some(b"h2");
405
406 let io = TokioIo::new(stream);
407
408 if is_h2 {
409 let exec = TokioExecutor::new();
410 let svc = service_fn(move |req| {
411 handle_request(
412 req,
413 client.clone(),
414 config.clone(),
415 metrics.clone(),
416 challenge_store.clone(),
417 lua_engine.clone(),
418 circuit_breaker.clone(),
419 app_manager.clone(),
420 )
421 });
422 let conn = hyper::server::conn::http2::Builder::new(exec)
423 .initial_stream_window_size(1024 * 1024)
424 .initial_connection_window_size(2 * 1024 * 1024)
425 .max_concurrent_streams(250)
426 .serve_connection(io, svc);
427 if let Err(e) = conn.await {
428 tracing::debug!("HTTPS/2 connection error: {}", e);
429 }
430 } else {
431 let svc = service_fn(move |req| {
432 handle_request(
433 req,
434 client.clone(),
435 config.clone(),
436 metrics.clone(),
437 challenge_store.clone(),
438 lua_engine.clone(),
439 circuit_breaker.clone(),
440 app_manager.clone(),
441 )
442 });
443 let conn = hyper::server::conn::http1::Builder::new()
444 .keep_alive(true)
445 .pipeline_flush(true)
446 .serve_connection(io, svc)
447 .with_upgrades();
448 if let Err(e) = conn.await {
449 tracing::debug!("HTTPS/1.1 connection error: {}", e);
450 }
451 }
452
453 Ok(())
454}
455
/// Copy the request headers into a lowercase-name → value map for Lua hooks.
///
/// Values that are not visible ASCII become `""`; when a header name repeats,
/// the last value wins.
#[cfg(feature = "scripting")]
fn extract_headers(req: &Request<Incoming>) -> std::collections::HashMap<String, String> {
    let mut headers = std::collections::HashMap::new();
    for (name, value) in req.headers().iter() {
        let text = String::from(value.to_str().unwrap_or(""));
        headers.insert(name.as_str().to_ascii_lowercase(), text);
    }
    headers
}
469
/// Snapshot the request into the `LuaRequest` passed to Lua hooks.
///
/// The host is taken from the request URI, falling back to the `Host` header
/// and then `""`. A missing or unparsable `Content-Length` becomes `0`.
#[cfg(feature = "scripting")]
fn build_lua_request(req: &Request<Incoming>) -> LuaRequest {
    let headers = req.headers();

    let host: &str = match req.uri().host() {
        Some(uri_host) => uri_host,
        None => headers
            .get("host")
            .and_then(|value| value.to_str().ok())
            .unwrap_or(""),
    };

    let content_length = headers
        .get("content-length")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.parse().ok())
        .unwrap_or(0);

    LuaRequest {
        method: req.method().to_string(),
        path: req.uri().path().to_string(),
        headers: extract_headers(req),
        host: host.to_owned(),
        content_length,
    }
}
495
/// Copy a response header map into a lowercase-name → value map for Lua hooks.
///
/// Values that are not visible ASCII become `""`; when a header name repeats,
/// the last value wins.
#[cfg(feature = "scripting")]
fn extract_response_headers(
    headers: &hyper::HeaderMap,
) -> std::collections::HashMap<String, String> {
    let mut snapshot = std::collections::HashMap::with_capacity(headers.len());
    for (name, value) in headers {
        let text = value.to_str().unwrap_or_default().to_owned();
        snapshot.insert(name.as_str().to_ascii_lowercase(), text);
    }
    snapshot
}
511
512#[allow(clippy::too_many_arguments)]
513async fn handle_request(
514 req: Request<Incoming>,
515 client: ClientType,
516 config_manager: Arc<ConfigManager>,
517 metrics: SharedMetrics,
518 challenge_store: ChallengeStore,
519 lua_engine: OptionalLuaEngine,
520 circuit_breaker: SharedCircuitBreaker,
521 app_manager: Option<Arc<AppManager>>,
522) -> Result<Response<BoxBody>, hyper::Error> {
523 let start_time = std::time::Instant::now();
524 metrics.inc_in_flight();
525 let config = config_manager.get_config();
526
527 if let Some(response) = handle_acme_challenge(&req, &challenge_store) {
529 metrics.dec_in_flight();
530 return Ok(response);
531 }
532
533 if is_metrics_request(&req) {
534 let duration = start_time.elapsed();
535 metrics.dec_in_flight();
536 let metrics_output = metrics.format_metrics();
537 metrics.record_request(0, metrics_output.len() as u64, 200, duration);
538 let body = http_body_util::Full::new(Bytes::from(metrics_output)).boxed();
539 return Ok(Response::builder()
540 .status(200)
541 .header("Content-Type", "text/plain")
542 .body(body)
543 .unwrap());
544 }
545
546 #[cfg(feature = "scripting")]
548 if let Some(ref engine) = lua_engine {
549 if engine.has_on_request() {
550 let mut lua_req = build_lua_request(&req);
551 match engine.call_on_request(&mut lua_req) {
552 RequestHookResult::Deny { status, body } => {
553 metrics.dec_in_flight();
554 let duration = start_time.elapsed();
555 metrics.record_request(0, body.len() as u64, status, duration);
556 let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
557 return Ok(Response::builder().status(status).body(resp_body).unwrap());
558 }
559 RequestHookResult::Continue(updated_req) => {
560 let _ = updated_req;
566 }
567 }
568 }
569 }
570
571 let is_websocket = is_websocket_request(&req);
572
573 if is_websocket {
574 return handle_websocket_request(
575 req,
576 client,
577 &config,
578 &metrics,
579 start_time,
580 app_manager.clone(),
581 )
582 .await;
583 }
584
585 let result = handle_regular_request(
586 req,
587 client,
588 &config,
589 &lua_engine,
590 &circuit_breaker,
591 app_manager.clone(),
592 )
593 .await;
594 let duration = start_time.elapsed();
595
596 metrics.dec_in_flight();
597
598 match result {
599 #[allow(unused_variables)]
600 Ok((response, _target_url, route_scripts)) => {
601 let status = response.status().as_u16();
602
603 #[cfg(feature = "scripting")]
605 if let Some(ref engine) = lua_engine {
606 let lua_req = LuaRequest {
607 method: String::new(),
608 path: String::new(),
609 headers: std::collections::HashMap::new(),
610 host: String::new(),
611 content_length: 0,
612 };
613 let duration_ms = duration.as_secs_f64() * 1000.0;
614
615 if engine.has_on_request_end() {
617 engine.call_on_request_end(&lua_req, status, duration_ms, &_target_url);
618 }
619
620 for script_name in &route_scripts {
622 engine.call_route_on_request_end(
623 script_name,
624 &lua_req,
625 status,
626 duration_ms,
627 &_target_url,
628 );
629 }
630 }
631
632 metrics.record_request(0, 0, status, duration);
633 record_app_metrics(&metrics, &app_manager, &_target_url, 0, 0, status, duration);
634 let (parts, body) = response.into_parts();
635 let boxed = body.map_err(|_| unreachable!()).boxed();
636 Ok(Response::from_parts(parts, boxed))
637 }
638 Err(e) => {
639 metrics.inc_errors();
640 Err(e)
641 }
642 }
643}
644
645fn is_websocket_request(req: &Request<Incoming>) -> bool {
646 if let Some(upgrade) = req.headers().get("upgrade") {
647 if upgrade == "websocket" {
648 return true;
649 }
650 }
651 false
652}
653
654fn is_metrics_request(req: &Request<Incoming>) -> bool {
655 req.uri().path() == "/metrics"
656}
657
658fn handle_acme_challenge(
659 req: &Request<Incoming>,
660 challenge_store: &ChallengeStore,
661) -> Option<Response<BoxBody>> {
662 let path = req.uri().path();
663 let prefix = "/.well-known/acme-challenge/";
664
665 if !path.starts_with(prefix) {
666 return None;
667 }
668
669 let token = &path[prefix.len()..];
670
671 if let Ok(store) = challenge_store.read() {
672 if let Some(key_auth) = store.get(token) {
673 let body = http_body_util::Full::new(Bytes::from(key_auth.clone())).boxed();
674 return Some(
675 Response::builder()
676 .status(200)
677 .header("Content-Type", "text/plain")
678 .body(body)
679 .unwrap(),
680 );
681 }
682 }
683
684 let body = http_body_util::Full::new(Bytes::from("Challenge not found")).boxed();
685 Some(Response::builder().status(404).body(body).unwrap())
686}
687
688async fn handle_websocket_request(
689 req: Request<Incoming>,
690 _client: ClientType,
691 config: &crate::config::Config,
692 metrics: &SharedMetrics,
693 _start_time: std::time::Instant,
694 _app_manager: Option<Arc<AppManager>>,
695) -> Result<Response<BoxBody>, hyper::Error> {
696 let target_result = find_target(&req, &config.rules);
697
698 if target_result.is_none() {
699 metrics.inc_errors();
700 let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
701 return Ok(Response::builder().status(421).body(body).unwrap());
702 }
703
704 let (target_url, _, _, _) = target_result.unwrap();
705
706 let backend_addr = match url::Url::parse(&target_url) {
708 Ok(u) => format!(
709 "{}:{}",
710 u.host_str().unwrap_or("127.0.0.1"),
711 u.port().unwrap_or(80)
712 ),
713 Err(_) => {
714 metrics.inc_errors();
715 let body = http_body_util::Full::new(Bytes::from("Bad backend URL")).boxed();
716 return Ok(Response::builder().status(502).body(body).unwrap());
717 }
718 };
719
720 let path = req.uri().path().to_string();
721 let query = req
722 .uri()
723 .query()
724 .map(|q| format!("?{}", q))
725 .unwrap_or_default();
726
727 let ws_key = req
728 .headers()
729 .get("sec-websocket-key")
730 .and_then(|v| v.to_str().ok())
731 .unwrap_or("")
732 .to_string();
733 let ws_version = req
734 .headers()
735 .get("sec-websocket-version")
736 .and_then(|v| v.to_str().ok())
737 .unwrap_or("13")
738 .to_string();
739 let ws_protocol = req
740 .headers()
741 .get("sec-websocket-protocol")
742 .and_then(|v| v.to_str().ok())
743 .map(|s| s.to_string());
744 let host_header = req
745 .headers()
746 .get("host")
747 .and_then(|v| v.to_str().ok())
748 .unwrap_or(&backend_addr)
749 .to_string();
750
751 tracing::info!(
752 "WebSocket upgrade request to {}{}{}",
753 backend_addr,
754 path,
755 query
756 );
757
758 let backend = match TcpStream::connect(&backend_addr).await {
760 Ok(s) => s,
761 Err(e) => {
762 tracing::error!("Failed to connect to backend for WebSocket: {}", e);
763 metrics.inc_errors();
764 let body = http_body_util::Full::new(Bytes::from("Backend not reachable")).boxed();
765 return Ok(Response::builder().status(502).body(body).unwrap());
766 }
767 };
768
769 let mut handshake = format!(
771 "GET {}{} HTTP/1.1\r\n\
772 Host: {}\r\n\
773 Upgrade: websocket\r\n\
774 Connection: Upgrade\r\n\
775 Sec-WebSocket-Key: {}\r\n\
776 Sec-WebSocket-Version: {}\r\n",
777 path, query, host_header, ws_key, ws_version,
778 );
779 if let Some(proto) = &ws_protocol {
780 handshake.push_str(&format!("Sec-WebSocket-Protocol: {}\r\n", proto));
781 }
782 handshake.push_str("\r\n");
783
784 let (mut backend_read, mut backend_write) = backend.into_split();
785 if let Err(e) = backend_write.write_all(handshake.as_bytes()).await {
786 tracing::error!("Failed to send WebSocket handshake to backend: {}", e);
787 metrics.inc_errors();
788 let body =
789 http_body_util::Full::new(Bytes::from("Failed to initiate WebSocket with backend"))
790 .boxed();
791 return Ok(Response::builder().status(502).body(body).unwrap());
792 }
793
794 let mut response_buf = vec![0u8; 4096];
796 let n = match tokio::io::AsyncReadExt::read(&mut backend_read, &mut response_buf).await {
797 Ok(n) if n > 0 => n,
798 _ => {
799 tracing::error!("No response from backend for WebSocket upgrade");
800 metrics.inc_errors();
801 let body = http_body_util::Full::new(Bytes::from(
802 "Backend did not respond to WebSocket upgrade",
803 ))
804 .boxed();
805 return Ok(Response::builder().status(502).body(body).unwrap());
806 }
807 };
808
809 let response_str = String::from_utf8_lossy(&response_buf[..n]);
810 if !response_str.contains("101") {
811 tracing::error!(
812 "Backend rejected WebSocket upgrade: {}",
813 response_str.lines().next().unwrap_or("")
814 );
815 metrics.inc_errors();
816 let body =
817 http_body_util::Full::new(Bytes::from("Backend rejected WebSocket upgrade")).boxed();
818 return Ok(Response::builder().status(502).body(body).unwrap());
819 }
820
821 let mut accept_key = String::new();
823 let mut resp_protocol = None;
824 for line in response_str.lines().skip(1) {
825 if line.trim().is_empty() {
826 break;
827 }
828 if let Some((name, value)) = line.split_once(':') {
829 let name_lower = name.trim().to_lowercase();
830 let value = value.trim().to_string();
831 if name_lower == "sec-websocket-accept" {
832 accept_key = value;
833 } else if name_lower == "sec-websocket-protocol" {
834 resp_protocol = Some(value);
835 }
836 }
837 }
838
839 let client_upgrade = hyper::upgrade::on(req);
841
842 let backend_stream = backend_read.reunite(backend_write).unwrap();
844
845 tokio::spawn(async move {
847 match client_upgrade.await {
848 Ok(upgraded) => {
849 let mut client_stream = TokioIo::new(upgraded);
850 let (mut br, mut bw) = tokio::io::split(backend_stream);
851 let (mut cr, mut cw) = tokio::io::split(&mut client_stream);
852 let _ = tokio::join!(
853 tokio::io::copy(&mut br, &mut cw),
854 tokio::io::copy(&mut cr, &mut bw),
855 );
856 }
857 Err(e) => {
858 tracing::error!("WebSocket client upgrade failed: {}", e);
859 }
860 }
861 });
862
863 let mut resp = Response::builder()
865 .status(101)
866 .header("Upgrade", "websocket")
867 .header("Connection", "Upgrade")
868 .header("Sec-WebSocket-Accept", accept_key);
869 if let Some(proto) = resp_protocol {
870 resp = resp.header("Sec-WebSocket-Protocol", proto);
871 }
872 Ok(resp
873 .body(http_body_util::Full::new(Bytes::new()).boxed())
874 .unwrap())
875}
876
877async fn handle_regular_request(
879 req: Request<Incoming>,
880 client: ClientType,
881 config: &crate::config::Config,
882 lua_engine: &OptionalLuaEngine,
883 circuit_breaker: &SharedCircuitBreaker,
884 _app_manager: Option<Arc<AppManager>>,
885) -> Result<(Response<BoxBody>, String, Vec<String>), hyper::Error> {
886 let route = find_matching_rule(&req, &config.rules);
887
888 match route {
889 #[allow(unused_mut, unused_variables)]
890 Some(matched_route) => {
891 let path = req.uri().path().to_string();
892 let from_domain_rule = matched_route.from_domain_rule;
893 let matched_prefix = matched_route.matched_prefix();
894 let route_scripts = matched_route.route_scripts.clone();
895
896 let target_selection = select_target(&matched_route, &path, circuit_breaker);
898 let (mut target_url, base_url) = match target_selection {
899 Some((url, base)) => (url, base),
900 None => {
901 let body =
903 http_body_util::Full::new(Bytes::from("Service Unavailable")).boxed();
904 return Ok((
905 Response::builder()
906 .status(503)
907 .body(body)
908 .expect("Failed to build response"),
909 String::new(),
910 route_scripts,
911 ));
912 }
913 };
914 #[cfg(feature = "scripting")]
916 if let Some(ref engine) = lua_engine {
917 for script_name in &route_scripts {
918 let mut lua_req = build_lua_request(&req);
919 match engine.call_route_on_request(script_name, &mut lua_req) {
920 RequestHookResult::Deny { status, body } => {
921 let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
922 return Ok((
923 Response::builder().status(status).body(resp_body).unwrap(),
924 target_url,
925 route_scripts.clone(),
926 ));
927 }
928 RequestHookResult::Continue(_) => {}
929 }
930 }
931 }
932
933 #[cfg(feature = "scripting")]
935 if let Some(ref engine) = lua_engine {
936 if engine.has_on_route() {
937 let lua_req = build_lua_request(&req);
938 match engine.call_on_route(&lua_req, &target_url) {
939 RouteHookResult::Override(new_url) => {
940 target_url = new_url;
941 }
942 RouteHookResult::Default => {}
943 }
944 }
945 for script_name in &route_scripts {
947 let lua_req = build_lua_request(&req);
948 match engine.call_route_on_route(script_name, &lua_req, &target_url) {
949 RouteHookResult::Override(new_url) => {
950 target_url = new_url;
951 }
952 RouteHookResult::Default => {}
953 }
954 }
955 }
956
957 let host_header = if from_domain_rule {
959 req.uri()
960 .host()
961 .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
962 .map(|s| s.to_string())
963 } else {
964 None
965 };
966
967 let (mut parts, body) = req.into_parts();
968
969 let uri: hyper::Uri = target_url.parse().expect("valid URI");
971 parts.uri = uri;
972 parts.version = http::Version::HTTP_11;
973 parts.extensions = http::Extensions::new();
974
975 let mut request = Request::from_parts(parts, body);
976
977 request
978 .headers_mut()
979 .insert("X-Forwarded-For", X_FORWARDED_FOR_VALUE.clone());
980
981 if from_domain_rule {
982 if let Some(host) = host_header {
983 request
984 .headers_mut()
985 .insert("X-Forwarded-Host", host.parse().unwrap());
986 }
987 }
988
989 match client.request(request).await {
990 Ok(response) => {
991 let status_code = response.status().as_u16();
993 if circuit_breaker.is_failure_status(status_code) {
994 circuit_breaker.record_failure(&base_url);
995 } else {
996 circuit_breaker.record_success(&base_url);
997 }
998
999 #[cfg(feature = "scripting")]
1001 if let Some(ref engine) = lua_engine {
1002 let has_global = engine.has_on_response();
1003 let has_route = !route_scripts.is_empty();
1004
1005 if has_global || has_route {
1006 use crate::scripting::ResponseMod;
1007
1008 let lua_req = LuaRequest {
1009 method: String::new(),
1010 path: String::new(),
1011 headers: std::collections::HashMap::new(),
1012 host: String::new(),
1013 content_length: 0,
1014 };
1015 let resp_headers = extract_response_headers(response.headers());
1016 let resp_status = response.status().as_u16();
1017
1018 let mut all_mods: Vec<ResponseMod> = Vec::new();
1020 if has_global {
1021 all_mods.push(engine.call_on_response(
1022 &lua_req,
1023 resp_status,
1024 &resp_headers,
1025 ));
1026 }
1027 for script_name in &route_scripts {
1028 all_mods.push(engine.call_route_on_response(
1029 script_name,
1030 &lua_req,
1031 resp_status,
1032 &resp_headers,
1033 ));
1034 }
1035
1036 let mut merged = ResponseMod::default();
1038 for mods in all_mods {
1039 merged.set_headers.extend(mods.set_headers);
1040 merged.remove_headers.extend(mods.remove_headers);
1041 if mods.replace_body.is_some() {
1042 merged.replace_body = mods.replace_body;
1043 }
1044 if mods.override_status.is_some() {
1045 merged.override_status = mods.override_status;
1046 }
1047 }
1048
1049 if !merged.set_headers.is_empty()
1051 || !merged.remove_headers.is_empty()
1052 || merged.replace_body.is_some()
1053 || merged.override_status.is_some()
1054 {
1055 let (mut parts, body) = response.into_parts();
1056
1057 if let Some(status) = merged.override_status {
1058 parts.status =
1059 hyper::StatusCode::from_u16(status).unwrap_or(parts.status);
1060 }
1061
1062 for name in &merged.remove_headers {
1063 if let Ok(header_name) =
1064 name.parse::<hyper::header::HeaderName>()
1065 {
1066 parts.headers.remove(header_name);
1067 }
1068 }
1069
1070 for (name, value) in &merged.set_headers {
1071 if let (Ok(header_name), Ok(header_value)) = (
1072 name.parse::<hyper::header::HeaderName>(),
1073 value.parse::<HeaderValue>(),
1074 ) {
1075 parts.headers.insert(header_name, header_value);
1076 }
1077 }
1078
1079 if let Some(new_body) = merged.replace_body {
1080 let new_bytes = Bytes::from(new_body);
1081 parts.headers.remove("content-length");
1082 parts.headers.insert(
1083 "content-length",
1084 new_bytes.len().to_string().parse().unwrap(),
1085 );
1086 let boxed = http_body_util::Full::new(new_bytes).boxed();
1087 return Ok((
1088 Response::from_parts(parts, boxed),
1089 target_url,
1090 route_scripts.clone(),
1091 ));
1092 }
1093
1094 let boxed = body.map_err(|_| unreachable!()).boxed();
1095 return Ok((
1096 Response::from_parts(parts, boxed),
1097 target_url,
1098 route_scripts.clone(),
1099 ));
1100 }
1101 }
1102 }
1103
1104 let is_html = response
1105 .headers()
1106 .get("content-type")
1107 .and_then(|v| v.to_str().ok())
1108 .map(|ct| ct.starts_with("text/html"))
1109 .unwrap_or(false);
1110
1111 if is_html {
1112 if let Some(prefix) = matched_prefix {
1113 let (parts, body) = response.into_parts();
1114 let body_bytes = body
1115 .collect()
1116 .await
1117 .map(|collected| collected.to_bytes())
1118 .unwrap_or_default();
1119
1120 let is_gzip = parts
1122 .headers
1123 .get("content-encoding")
1124 .and_then(|v| v.to_str().ok())
1125 .map(|v| v.contains("gzip"))
1126 .unwrap_or(false);
1127 let is_deflate = parts
1128 .headers
1129 .get("content-encoding")
1130 .and_then(|v| v.to_str().ok())
1131 .map(|v| v.contains("deflate"))
1132 .unwrap_or(false);
1133
1134 let raw_bytes = if is_gzip {
1135 use std::io::Read;
1136 let mut decoder = flate2::read::GzDecoder::new(&body_bytes[..]);
1137 let mut decoded = Vec::new();
1138 decoder.read_to_end(&mut decoded).unwrap_or_default();
1139 Bytes::from(decoded)
1140 } else if is_deflate {
1141 use std::io::Read;
1142 let mut decoder =
1143 flate2::read::DeflateDecoder::new(&body_bytes[..]);
1144 let mut decoded = Vec::new();
1145 decoder.read_to_end(&mut decoded).unwrap_or_default();
1146 Bytes::from(decoded)
1147 } else {
1148 body_bytes
1149 };
1150
1151 let html = String::from_utf8_lossy(&raw_bytes);
1152 let rewritten = html
1153 .replace("href=\"/", &format!("href=\"{}/", prefix))
1154 .replace("src=\"/", &format!("src=\"{}/", prefix))
1155 .replace("action=\"/", &format!("action=\"{}/", prefix));
1156 let rewritten_bytes = Bytes::from(rewritten);
1157 let mut parts = parts;
1158 parts.headers.remove("content-encoding");
1159 parts.headers.remove("content-length");
1160 parts.headers.insert(
1161 "content-length",
1162 rewritten_bytes.len().to_string().parse().unwrap(),
1163 );
1164 let boxed = http_body_util::Full::new(rewritten_bytes).boxed();
1165 return Ok((
1166 Response::from_parts(parts, boxed),
1167 target_url,
1168 route_scripts.clone(),
1169 ));
1170 }
1171 }
1172
1173 let (parts, body) = response.into_parts();
1174 let boxed = body.map_err(|_| unreachable!()).boxed();
1175 Ok((
1176 Response::from_parts(parts, boxed),
1177 target_url,
1178 route_scripts,
1179 ))
1180 }
1181 Err(e) => {
1182 circuit_breaker.record_failure(&base_url);
1183 tracing::error!("Backend request failed: {} (target: {})", e, target_url);
1184 let body = http_body_util::Full::new(Bytes::from("Bad Gateway")).boxed();
1185 Ok((
1186 Response::builder()
1187 .status(502)
1188 .body(body)
1189 .expect("Failed to build response"),
1190 target_url,
1191 route_scripts,
1192 ))
1193 }
1194 }
1195 }
1196 None => {
1197 let _ = lua_engine;
1199 let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
1200 Ok((
1201 Response::builder()
1202 .status(421)
1203 .body(body)
1204 .expect("Failed to build response"),
1205 String::new(),
1206 vec![],
1207 ))
1208 }
1209 }
1210}
1211
/// How the request path is combined with a selected target's URL.
#[derive(Debug, Clone, PartialEq, Eq)]
enum UrlResolution {
    /// Append the whole request path to the target URL.
    AppendPath,
    /// Remove the contained rule prefix from the request path, then append the remainder.
    StripPrefix(String),
    /// Use the target URL unchanged, ignoring the request path.
    Identity,
}
1221
1222struct MatchedRoute<'a> {
1224 targets: &'a [crate::config::Target],
1225 from_domain_rule: bool,
1226 resolution: UrlResolution,
1227 route_scripts: Vec<String>,
1228}
1229
1230impl<'a> MatchedRoute<'a> {
1231 fn matched_prefix(&self) -> Option<String> {
1232 match &self.resolution {
1233 UrlResolution::StripPrefix(prefix) => Some(prefix.trim_end_matches('/').to_string()),
1234 _ => None,
1235 }
1236 }
1237}
1238
1239fn resolve_target_url(
1241 target: &crate::config::Target,
1242 path: &str,
1243 resolution: &UrlResolution,
1244) -> String {
1245 let target_str = target.url.as_str();
1246 match resolution {
1247 UrlResolution::AppendPath => {
1248 if target_str.ends_with('/') {
1249 format!("{}{}", target_str, &path[1..])
1250 } else {
1251 format!("{}{}", target_str, path)
1252 }
1253 }
1254 UrlResolution::StripPrefix(prefix) => {
1255 let suffix = if path.len() >= prefix.len() {
1256 &path[prefix.len()..]
1257 } else {
1258 ""
1259 };
1260 format!("{}{}", target_str, suffix)
1261 }
1262 UrlResolution::Identity => target_str.to_owned(),
1263 }
1264}
1265
1266fn find_matching_rule<'a>(
1268 req: &Request<Incoming>,
1269 rules: &'a [crate::config::ProxyRule],
1270) -> Option<MatchedRoute<'a>> {
1271 let host = req
1272 .uri()
1273 .host()
1274 .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1275 .map(|h| h.split(':').next().unwrap_or(h))?;
1276
1277 let path = req.uri().path();
1278 let mut matched_domain = false;
1279
1280 for rule in rules {
1281 match &rule.matcher {
1282 crate::config::RuleMatcher::Domain(domain) => {
1283 if domain == host {
1284 matched_domain = true;
1285 if !rule.targets.is_empty() {
1286 return Some(MatchedRoute {
1287 targets: &rule.targets,
1288 from_domain_rule: true,
1289 resolution: UrlResolution::AppendPath,
1290 route_scripts: rule.scripts.clone(),
1291 });
1292 }
1293 }
1294 }
1295 crate::config::RuleMatcher::DomainPath(domain, path_prefix) => {
1296 if domain == host && !rule.targets.is_empty() {
1297 let matches = path.starts_with(path_prefix)
1298 || (path_prefix.ends_with('/')
1299 && path == path_prefix.trim_end_matches('/'));
1300 if matches {
1301 return Some(MatchedRoute {
1302 targets: &rule.targets,
1303 from_domain_rule: true,
1304 resolution: UrlResolution::StripPrefix(path_prefix.clone()),
1305 route_scripts: rule.scripts.clone(),
1306 });
1307 }
1308 }
1309 }
1310 _ => {}
1311 }
1312 }
1313
1314 if matched_domain {
1315 return None;
1316 }
1317
1318 for rule in rules {
1320 match &rule.matcher {
1321 crate::config::RuleMatcher::Exact(exact) => {
1322 if path == exact && !rule.targets.is_empty() {
1323 return Some(MatchedRoute {
1324 targets: &rule.targets,
1325 from_domain_rule: false,
1326 resolution: UrlResolution::Identity,
1327 route_scripts: rule.scripts.clone(),
1328 });
1329 }
1330 }
1331 crate::config::RuleMatcher::Prefix(prefix) => {
1332 if !rule.targets.is_empty() {
1333 let matches = path.starts_with(prefix)
1335 || (prefix.ends_with('/') && path == prefix.trim_end_matches('/'));
1336 if matches {
1337 return Some(MatchedRoute {
1338 targets: &rule.targets,
1339 from_domain_rule: false,
1340 resolution: UrlResolution::StripPrefix(prefix.clone()),
1341 route_scripts: rule.scripts.clone(),
1342 });
1343 }
1344 }
1345 }
1346 crate::config::RuleMatcher::Regex(ref rm) => {
1347 if rm.is_match(path) && !rule.targets.is_empty() {
1348 return Some(MatchedRoute {
1349 targets: &rule.targets,
1350 from_domain_rule: false,
1351 resolution: UrlResolution::Identity,
1352 route_scripts: rule.scripts.clone(),
1353 });
1354 }
1355 }
1356 _ => {}
1357 }
1358 }
1359
1360 for rule in rules {
1362 if let crate::config::RuleMatcher::Default = &rule.matcher {
1363 if !rule.targets.is_empty() {
1364 return Some(MatchedRoute {
1365 targets: &rule.targets,
1366 from_domain_rule: false,
1367 resolution: UrlResolution::AppendPath,
1368 route_scripts: rule.scripts.clone(),
1369 });
1370 }
1371 }
1372 }
1373
1374 None
1375}
1376
1377fn select_target(
1380 route: &MatchedRoute<'_>,
1381 path: &str,
1382 circuit_breaker: &crate::circuit_breaker::CircuitBreaker,
1383) -> Option<(String, String)> {
1384 for target in route.targets {
1385 let base_url = target.url.as_str().to_owned();
1386 if circuit_breaker.is_available(&base_url) {
1387 let resolved = resolve_target_url(target, path, &route.resolution);
1388 return Some((resolved, base_url));
1389 }
1390 }
1391 None
1392}
1393
1394fn find_target(
1396 req: &Request<Incoming>,
1397 rules: &[crate::config::ProxyRule],
1398) -> Option<(String, bool, Option<String>, Vec<String>)> {
1399 let route = find_matching_rule(req, rules)?;
1400 let path = req.uri().path();
1401 let target = route.targets.first()?;
1402 let resolved = resolve_target_url(target, path, &route.resolution);
1403 let matched_prefix = route.matched_prefix();
1404 Some((
1405 resolved,
1406 route.from_domain_rule,
1407 matched_prefix,
1408 route.route_scripts,
1409 ))
1410}