Skip to main content

soli_proxy/server/
mod.rs

1// When scripting feature is disabled, OptionalLuaEngine = () and cloning it triggers warnings
2#![allow(clippy::let_unit_value, clippy::clone_on_copy, clippy::unit_arg)]
3
4use crate::acme::ChallengeStore;
5use crate::app::AppManager;
6use crate::auth;
7use crate::circuit_breaker::SharedCircuitBreaker;
8use crate::config::ConfigManager;
9use crate::metrics::SharedMetrics;
10use crate::shutdown::ShutdownCoordinator;
11use anyhow::Result;
12use bytes::Bytes;
13use http_body_util::BodyExt;
14use hyper::body::Incoming;
15use hyper::header::HeaderValue;
16use hyper::service::service_fn;
17use hyper::Request;
18use hyper::Response;
19use hyper_util::client::legacy::connect::HttpConnector;
20use hyper_util::client::legacy::Client;
21use hyper_util::rt::TokioExecutor;
22use hyper_util::rt::TokioIo;
23use socket2::{Domain, Protocol, Socket, Type};
24use std::net::SocketAddr;
25use std::sync::Arc;
26use std::time::Duration;
27use tokio::io::AsyncWriteExt;
28use tokio::net::{TcpListener, TcpStream};
29use tokio_rustls::TlsAcceptor;
30
31#[cfg(feature = "scripting")]
32use crate::scripting::{LuaEngine, LuaRequest, RequestHookResult, RouteHookResult};
33
34type ClientType = Client<HttpConnector, Incoming>;
35type BoxBody = http_body_util::combinators::BoxBody<Bytes, std::convert::Infallible>;
36
37#[cfg(feature = "scripting")]
38type OptionalLuaEngine = Option<LuaEngine>;
39#[cfg(not(feature = "scripting"))]
40type OptionalLuaEngine = ();
41
42/// Helper to record app-specific metrics
43fn record_app_metrics(
44    metrics: &SharedMetrics,
45    app_manager: &Option<Arc<AppManager>>,
46    target_url: &str,
47    bytes_in: u64,
48    bytes_out: u64,
49    status: u16,
50    duration: Duration,
51) {
52    if let Some(ref manager) = app_manager {
53        if let Ok(url) = url::Url::parse(target_url) {
54            if let Some(port) = url.port() {
55                if let Some(app_name) = futures::executor::block_on(manager.get_app_name(port)) {
56                    metrics.record_app_request(&app_name, bytes_in, bytes_out, status, duration);
57                }
58            }
59        }
60    }
61}
62
63/// Pre-parsed header value for X-Forwarded-For to avoid parsing on every request
64static X_FORWARDED_FOR_VALUE: std::sync::LazyLock<HeaderValue> =
65    std::sync::LazyLock::new(|| HeaderValue::from_static("127.0.0.1"));
66
67/// Verify Basic Auth credentials against stored hashes
68/// Returns true if credentials are valid, false otherwise
69fn verify_basic_auth(req: &Request<Incoming>, auth_entries: &[crate::auth::BasicAuth]) -> bool {
70    if auth_entries.is_empty() {
71        return true;
72    }
73
74    let auth_header = req.headers().get("authorization");
75    if auth_header.is_none() {
76        return false;
77    }
78
79    let header_value = auth_header.unwrap().to_str().unwrap_or("");
80    if !header_value.starts_with("Basic ") {
81        return false;
82    }
83
84    let encoded = &header_value[6..];
85    let decoded = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, encoded)
86        .unwrap_or_default();
87    let creds = String::from_utf8_lossy(&decoded);
88
89    if let Some((username, password)) = creds.split_once(':') {
90        for entry in auth_entries {
91            if entry.username == username && auth::verify_password(password, &entry.hash) {
92                return true;
93            }
94        }
95    }
96
97    false
98}
99
100/// Create 401 Unauthorized response with WWW-Authenticate header
101fn create_auth_required_response() -> Response<BoxBody> {
102    let body = http_body_util::Full::new(Bytes::from("Authentication required")).boxed();
103    Response::builder()
104        .status(401)
105        .header("WWW-Authenticate", "Basic realm=\"Restricted\"")
106        .body(body)
107        .unwrap()
108}
109
110fn create_listener(addr: SocketAddr) -> Result<TcpListener> {
111    let domain = if addr.is_ipv4() {
112        Domain::IPV4
113    } else {
114        Domain::IPV6
115    };
116    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
117    socket.set_reuse_address(true)?;
118    socket.set_reuse_port(true)?;
119    socket.set_nonblocking(true)?;
120    socket.bind(&addr.into())?;
121    socket.listen(8192)?;
122    let std_listener: std::net::TcpListener = socket.into();
123    Ok(TcpListener::from_std(std_listener)?)
124}
125
126fn create_client() -> ClientType {
127    let exec = TokioExecutor::new();
128    let mut connector = HttpConnector::new();
129    connector.set_nodelay(true);
130    connector.set_keepalive(Some(Duration::from_secs(30)));
131    connector.set_connect_timeout(Some(Duration::from_secs(5)));
132    Client::builder(exec)
133        .pool_max_idle_per_host(256)
134        .pool_idle_timeout(Duration::from_secs(60))
135        .build(connector)
136}
137
/// Top-level proxy server: holds the shared state and spawns the HTTP (and optional HTTPS)
/// accept loops from `ProxyServer::run`.
pub struct ProxyServer {
    // Configuration source; `run` reads the bind address via `get_config`.
    config: Arc<ConfigManager>,
    // Checked via `is_shutting_down` by `run` and by each accept loop.
    shutdown: ShutdownCoordinator,
    // TLS acceptor; `Some` only when built with `with_https`.
    tls_acceptor: Option<TlsAcceptor>,
    // HTTPS bind address; `Some` only when built with `with_https`.
    https_addr: Option<SocketAddr>,
    metrics: SharedMetrics,
    // ACME challenge tokens, answered by `handle_acme_challenge`.
    challenge_store: ChallengeStore,
    // Lua engine, or `()` when the `scripting` feature is disabled.
    lua_engine: OptionalLuaEngine,
    circuit_breaker: SharedCircuitBreaker,
    // Optional port -> app-name lookup used for per-app metrics.
    app_manager: Option<Arc<AppManager>>,
}
149
150impl ProxyServer {
151    pub fn new(
152        config: Arc<ConfigManager>,
153        shutdown: ShutdownCoordinator,
154        metrics: SharedMetrics,
155        challenge_store: ChallengeStore,
156        lua_engine: OptionalLuaEngine,
157        circuit_breaker: SharedCircuitBreaker,
158        app_manager: Option<Arc<AppManager>>,
159    ) -> Result<Self> {
160        Ok(Self {
161            config,
162            shutdown,
163            tls_acceptor: None,
164            https_addr: None,
165            metrics,
166            challenge_store,
167            lua_engine,
168            circuit_breaker,
169            app_manager,
170        })
171    }
172
173    #[allow(clippy::too_many_arguments)]
174    pub fn with_https(
175        config: Arc<ConfigManager>,
176        shutdown: ShutdownCoordinator,
177        tls_acceptor: TlsAcceptor,
178        https_addr: SocketAddr,
179        metrics: SharedMetrics,
180        challenge_store: ChallengeStore,
181        lua_engine: OptionalLuaEngine,
182        circuit_breaker: SharedCircuitBreaker,
183        app_manager: Option<Arc<AppManager>>,
184    ) -> Result<Self> {
185        Ok(Self {
186            config,
187            shutdown,
188            tls_acceptor: Some(tls_acceptor),
189            https_addr: Some(https_addr),
190            metrics,
191            challenge_store,
192            lua_engine,
193            circuit_breaker,
194            app_manager,
195        })
196    }
197
198    pub async fn run(&self) -> Result<()> {
199        let cfg = self.config.get_config();
200        let http_addr: SocketAddr = cfg.server.bind.parse()?;
201        let https_addr = self.https_addr;
202
203        let has_https = https_addr.is_some();
204        let num_listeners = std::thread::available_parallelism()
205            .map(|n| n.get())
206            .unwrap_or(4);
207
208        // Spawn N HTTP accept loops with SO_REUSEPORT
209        // Each listener gets its own client with its own connection pool to avoid contention
210        let app_manager = self.app_manager.clone();
211        for i in 0..num_listeners {
212            let config_clone = self.config.clone();
213            let shutdown_clone = self.shutdown.clone();
214            let metrics_clone = self.metrics.clone();
215            let challenge_store_clone = self.challenge_store.clone();
216            let lua_clone = self.lua_engine.clone();
217            let cb_clone = self.circuit_breaker.clone();
218            let am_clone = app_manager.clone();
219
220            tokio::spawn(async move {
221                if let Err(e) = run_http_server(
222                    http_addr,
223                    config_clone,
224                    shutdown_clone,
225                    metrics_clone,
226                    challenge_store_clone,
227                    lua_clone,
228                    cb_clone,
229                    am_clone,
230                )
231                .await
232                {
233                    tracing::error!("HTTP/1.1 server error (listener {}): {}", i, e);
234                }
235            });
236        }
237
238        if let Some(https_addr) = https_addr {
239            for i in 0..num_listeners {
240                let config_clone = self.config.clone();
241                let shutdown_clone = self.shutdown.clone();
242                let acceptor = self.tls_acceptor.as_ref().unwrap().clone();
243                let metrics_clone = self.metrics.clone();
244                let challenge_store_clone = self.challenge_store.clone();
245                let lua_clone = self.lua_engine.clone();
246                let cb_clone = self.circuit_breaker.clone();
247                let am_clone = app_manager.clone();
248
249                tokio::spawn(async move {
250                    if let Err(e) = run_https_server(
251                        https_addr,
252                        config_clone,
253                        shutdown_clone,
254                        acceptor,
255                        metrics_clone,
256                        challenge_store_clone,
257                        lua_clone,
258                        cb_clone,
259                        am_clone,
260                    )
261                    .await
262                    {
263                        tracing::error!("HTTPS/2 server error (listener {}): {}", i, e);
264                    }
265                });
266            }
267        }
268
269        tracing::info!(
270            "HTTP/1.1 server listening on {} ({} accept loops)",
271            http_addr,
272            num_listeners
273        );
274        if has_https {
275            tracing::info!(
276                "HTTPS/2 server listening on {} ({} accept loops)",
277                https_addr.unwrap(),
278                num_listeners
279            );
280        }
281
282        loop {
283            if self.shutdown.is_shutting_down() {
284                tracing::info!("Shutting down servers...");
285                break;
286            }
287            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
288        }
289
290        Ok(())
291    }
292}
293
294#[allow(clippy::too_many_arguments)]
295async fn run_http_server(
296    addr: SocketAddr,
297    config: Arc<ConfigManager>,
298    shutdown: ShutdownCoordinator,
299    metrics: SharedMetrics,
300    challenge_store: ChallengeStore,
301    lua_engine: OptionalLuaEngine,
302    circuit_breaker: SharedCircuitBreaker,
303    app_manager: Option<Arc<AppManager>>,
304) -> Result<()> {
305    let listener = create_listener(addr)?;
306    let client = create_client();
307
308    loop {
309        if shutdown.is_shutting_down() {
310            break;
311        }
312
313        match listener.accept().await {
314            Ok((stream, _)) => {
315                let _ = stream.set_nodelay(true);
316                let client = client.clone();
317                let config = config.clone();
318                let metrics = metrics.clone();
319                let cs = challenge_store.clone();
320                let lua = lua_engine.clone();
321                let cb = circuit_breaker.clone();
322                let am = app_manager.clone();
323                tokio::spawn(async move {
324                    if let Err(e) =
325                        handle_http11_connection(stream, client, config, metrics, cs, lua, cb, am)
326                            .await
327                    {
328                        tracing::debug!("HTTP/1.1 connection error: {}", e);
329                    }
330                });
331            }
332            Err(e) => {
333                tracing::error!("HTTP/1.1 accept error: {}", e);
334            }
335        }
336    }
337
338    Ok(())
339}
340
341#[allow(clippy::too_many_arguments)]
342async fn run_https_server(
343    addr: SocketAddr,
344    config: Arc<ConfigManager>,
345    shutdown: ShutdownCoordinator,
346    acceptor: TlsAcceptor,
347    metrics: SharedMetrics,
348    challenge_store: ChallengeStore,
349    lua_engine: OptionalLuaEngine,
350    circuit_breaker: SharedCircuitBreaker,
351    app_manager: Option<Arc<AppManager>>,
352) -> Result<()> {
353    let listener = create_listener(addr)?;
354    let client = create_client();
355
356    loop {
357        if shutdown.is_shutting_down() {
358            break;
359        }
360
361        match listener.accept().await {
362            Ok((stream, _)) => {
363                let _ = stream.set_nodelay(true);
364                let client = client.clone();
365                let config = config.clone();
366                let acceptor = acceptor.clone();
367                let metrics = metrics.clone();
368                let cs = challenge_store.clone();
369                let lua = lua_engine.clone();
370                let cb = circuit_breaker.clone();
371                let am = app_manager.clone();
372                tokio::spawn(async move {
373                    match acceptor.accept(stream).await {
374                        Ok(tls_stream) => {
375                            metrics.inc_tls_connections();
376                            if let Err(e) = handle_https2_connection(
377                                tls_stream, client, config, metrics, cs, lua, cb, am,
378                            )
379                            .await
380                            {
381                                tracing::debug!("HTTPS/2 connection error: {}", e);
382                            }
383                        }
384                        Err(e) => {
385                            tracing::error!("TLS accept error: {}", e);
386                        }
387                    }
388                });
389            }
390            Err(e) => {
391                tracing::error!("HTTPS/2 accept error: {}", e);
392            }
393        }
394    }
395
396    Ok(())
397}
398
399#[allow(clippy::too_many_arguments)]
400async fn handle_http11_connection(
401    stream: tokio::net::TcpStream,
402    client: ClientType,
403    config: Arc<ConfigManager>,
404    metrics: SharedMetrics,
405    challenge_store: ChallengeStore,
406    lua_engine: OptionalLuaEngine,
407    circuit_breaker: SharedCircuitBreaker,
408    app_manager: Option<Arc<AppManager>>,
409) -> Result<()> {
410    let io = TokioIo::new(stream);
411    let svc = service_fn(move |req| {
412        handle_request(
413            req,
414            client.clone(),
415            config.clone(),
416            metrics.clone(),
417            challenge_store.clone(),
418            lua_engine.clone(),
419            circuit_breaker.clone(),
420            app_manager.clone(),
421        )
422    });
423
424    let conn = hyper::server::conn::http1::Builder::new()
425        .keep_alive(true)
426        .pipeline_flush(true)
427        .serve_connection(io, svc)
428        .with_upgrades();
429
430    if let Err(e) = conn.await {
431        tracing::debug!("HTTP/1.1 connection error: {}", e);
432    }
433
434    Ok(())
435}
436
437#[allow(clippy::too_many_arguments)]
438async fn handle_https2_connection(
439    stream: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
440    client: ClientType,
441    config: Arc<ConfigManager>,
442    metrics: SharedMetrics,
443    challenge_store: ChallengeStore,
444    lua_engine: OptionalLuaEngine,
445    circuit_breaker: SharedCircuitBreaker,
446    app_manager: Option<Arc<AppManager>>,
447) -> Result<()> {
448    let is_h2 = stream.get_ref().1.alpn_protocol() == Some(b"h2");
449
450    let io = TokioIo::new(stream);
451
452    if is_h2 {
453        let exec = TokioExecutor::new();
454        let svc = service_fn(move |req| {
455            handle_request(
456                req,
457                client.clone(),
458                config.clone(),
459                metrics.clone(),
460                challenge_store.clone(),
461                lua_engine.clone(),
462                circuit_breaker.clone(),
463                app_manager.clone(),
464            )
465        });
466        let conn = hyper::server::conn::http2::Builder::new(exec)
467            .initial_stream_window_size(1024 * 1024)
468            .initial_connection_window_size(2 * 1024 * 1024)
469            .max_concurrent_streams(250)
470            .serve_connection(io, svc);
471        if let Err(e) = conn.await {
472            tracing::debug!("HTTPS/2 connection error: {}", e);
473        }
474    } else {
475        let svc = service_fn(move |req| {
476            handle_request(
477                req,
478                client.clone(),
479                config.clone(),
480                metrics.clone(),
481                challenge_store.clone(),
482                lua_engine.clone(),
483                circuit_breaker.clone(),
484                app_manager.clone(),
485            )
486        });
487        let conn = hyper::server::conn::http1::Builder::new()
488            .keep_alive(true)
489            .pipeline_flush(true)
490            .serve_connection(io, svc)
491            .with_upgrades();
492        if let Err(e) = conn.await {
493            tracing::debug!("HTTPS/1.1 connection error: {}", e);
494        }
495    }
496
497    Ok(())
498}
499
/// Collect the request headers into a name -> value map for Lua.
///
/// Names are lowercased. Values that are not visible ASCII become `""`. When a header
/// repeats, only its last value survives.
#[cfg(feature = "scripting")]
fn extract_headers(req: &Request<Incoming>) -> std::collections::HashMap<String, String> {
    // Request and response headers are flattened identically.
    extract_response_headers(req.headers())
}
513
/// Snapshot the parts of `req` that a Lua `on_request` hook can inspect.
///
/// `host` is taken from the request URI's host when the URI carries an authority (typically
/// HTTP/2 or absolute-form requests). Otherwise it comes from the `Host` header. It is empty
/// when neither is usable.
///
/// NOTE(review): the URI host excludes any port, while the `Host` header may include `:port`.
/// Scripts may therefore see either form.
///
/// `content_length` is the parsed `Content-Length` header, or 0 when it is absent or
/// unparsable.
#[cfg(feature = "scripting")]
fn build_lua_request(req: &Request<Incoming>) -> LuaRequest {
    let headers = req.headers();

    let host = match req.uri().host() {
        Some(uri_host) => uri_host.to_string(),
        None => headers
            .get("host")
            .and_then(|value| value.to_str().ok())
            .unwrap_or("")
            .to_string(),
    };

    let content_length = headers
        .get("content-length")
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.parse().ok())
        .unwrap_or(0);

    LuaRequest {
        method: req.method().to_string(),
        path: req.uri().path().to_string(),
        headers: extract_headers(req),
        host,
        content_length,
    }
}
539
/// Flatten a header map into a name -> value map for Lua.
///
/// Names are lowercased. Values that are not visible ASCII become `""`. When a name repeats,
/// the last value wins.
#[cfg(feature = "scripting")]
fn extract_response_headers(
    headers: &hyper::HeaderMap,
) -> std::collections::HashMap<String, String> {
    let mut flattened = std::collections::HashMap::with_capacity(headers.keys_len());
    for (name, value) in headers {
        flattened.insert(
            name.as_str().to_lowercase(),
            value.to_str().unwrap_or("").to_owned(),
        );
    }
    flattened
}
555
556#[allow(clippy::too_many_arguments)]
557async fn handle_request(
558    req: Request<Incoming>,
559    client: ClientType,
560    config_manager: Arc<ConfigManager>,
561    metrics: SharedMetrics,
562    challenge_store: ChallengeStore,
563    lua_engine: OptionalLuaEngine,
564    circuit_breaker: SharedCircuitBreaker,
565    app_manager: Option<Arc<AppManager>>,
566) -> Result<Response<BoxBody>, hyper::Error> {
567    let start_time = std::time::Instant::now();
568    metrics.inc_in_flight();
569    let config = config_manager.get_config();
570
571    // ACME challenge check — must come before all other routing
572    if let Some(response) = handle_acme_challenge(&req, &challenge_store) {
573        metrics.dec_in_flight();
574        return Ok(response);
575    }
576
577    if is_metrics_request(&req) {
578        let duration = start_time.elapsed();
579        metrics.dec_in_flight();
580        let metrics_output = metrics.format_metrics();
581        metrics.record_request(0, metrics_output.len() as u64, 200, duration);
582        let body = http_body_util::Full::new(Bytes::from(metrics_output)).boxed();
583        return Ok(Response::builder()
584            .status(200)
585            .header("Content-Type", "text/plain")
586            .body(body)
587            .unwrap());
588    }
589
590    // --- Lua on_request hook ---
591    #[cfg(feature = "scripting")]
592    if let Some(ref engine) = lua_engine {
593        if engine.has_on_request() {
594            let mut lua_req = build_lua_request(&req);
595            match engine.call_on_request(&mut lua_req) {
596                RequestHookResult::Deny { status, body } => {
597                    metrics.dec_in_flight();
598                    let duration = start_time.elapsed();
599                    metrics.record_request(0, body.len() as u64, status, duration);
600                    let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
601                    return Ok(Response::builder().status(status).body(resp_body).unwrap());
602                }
603                RequestHookResult::Continue(updated_req) => {
604                    // Apply any header modifications back to the hyper request
605                    // We can't easily mutate the incoming request headers here since
606                    // we'd need to own it, so we store the lua_req for later use.
607                    // Headers set via set_header in on_request will be applied after
608                    // the request is decomposed into parts.
609                    let _ = updated_req;
610                }
611            }
612        }
613    }
614
615    let is_websocket = is_websocket_request(&req);
616
617    if is_websocket {
618        return handle_websocket_request(
619            req,
620            client,
621            &config,
622            &metrics,
623            start_time,
624            app_manager.clone(),
625        )
626        .await;
627    }
628
629    let result = handle_regular_request(
630        req,
631        client,
632        &config,
633        &lua_engine,
634        &circuit_breaker,
635        app_manager.clone(),
636    )
637    .await;
638    let duration = start_time.elapsed();
639
640    metrics.dec_in_flight();
641
642    match result {
643        #[allow(unused_variables)]
644        Ok((response, _target_url, route_scripts)) => {
645            let status = response.status().as_u16();
646
647            // --- Lua on_request_end hooks (global + route) ---
648            #[cfg(feature = "scripting")]
649            if let Some(ref engine) = lua_engine {
650                let lua_req = LuaRequest {
651                    method: String::new(),
652                    path: String::new(),
653                    headers: std::collections::HashMap::new(),
654                    host: String::new(),
655                    content_length: 0,
656                };
657                let duration_ms = duration.as_secs_f64() * 1000.0;
658
659                // Global on_request_end
660                if engine.has_on_request_end() {
661                    engine.call_on_request_end(&lua_req, status, duration_ms, &_target_url);
662                }
663
664                // Route-specific on_request_end
665                for script_name in &route_scripts {
666                    engine.call_route_on_request_end(
667                        script_name,
668                        &lua_req,
669                        status,
670                        duration_ms,
671                        &_target_url,
672                    );
673                }
674            }
675
676            metrics.record_request(0, 0, status, duration);
677            record_app_metrics(&metrics, &app_manager, &_target_url, 0, 0, status, duration);
678            let (parts, body) = response.into_parts();
679            let boxed = body.map_err(|_| unreachable!()).boxed();
680            Ok(Response::from_parts(parts, boxed))
681        }
682        Err(e) => {
683            metrics.inc_errors();
684            Err(e)
685        }
686    }
687}
688
689fn is_websocket_request(req: &Request<Incoming>) -> bool {
690    if let Some(upgrade) = req.headers().get("upgrade") {
691        if upgrade == "websocket" {
692            return true;
693        }
694    }
695    false
696}
697
698fn is_metrics_request(req: &Request<Incoming>) -> bool {
699    req.uri().path() == "/metrics"
700}
701
702fn handle_acme_challenge(
703    req: &Request<Incoming>,
704    challenge_store: &ChallengeStore,
705) -> Option<Response<BoxBody>> {
706    let path = req.uri().path();
707    let prefix = "/.well-known/acme-challenge/";
708
709    if !path.starts_with(prefix) {
710        return None;
711    }
712
713    let token = &path[prefix.len()..];
714
715    if let Ok(store) = challenge_store.read() {
716        if let Some(key_auth) = store.get(token) {
717            let body = http_body_util::Full::new(Bytes::from(key_auth.clone())).boxed();
718            return Some(
719                Response::builder()
720                    .status(200)
721                    .header("Content-Type", "text/plain")
722                    .body(body)
723                    .unwrap(),
724            );
725        }
726    }
727
728    let body = http_body_util::Full::new(Bytes::from("Challenge not found")).boxed();
729    Some(Response::builder().status(404).body(body).unwrap())
730}
731
732async fn handle_websocket_request(
733    req: Request<Incoming>,
734    _client: ClientType,
735    config: &crate::config::Config,
736    metrics: &SharedMetrics,
737    _start_time: std::time::Instant,
738    _app_manager: Option<Arc<AppManager>>,
739) -> Result<Response<BoxBody>, hyper::Error> {
740    let target_result = find_target(&req, &config.rules);
741
742    if target_result.is_none() {
743        metrics.inc_errors();
744        let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
745        return Ok(Response::builder().status(421).body(body).unwrap());
746    }
747
748    let (target_url, _, _, _) = target_result.unwrap();
749
750    // Extract host:port from target URL (e.g. "http://127.0.0.1:3000/path" -> "127.0.0.1:3000")
751    let backend_addr = match url::Url::parse(&target_url) {
752        Ok(u) => format!(
753            "{}:{}",
754            u.host_str().unwrap_or("127.0.0.1"),
755            u.port().unwrap_or(80)
756        ),
757        Err(_) => {
758            metrics.inc_errors();
759            let body = http_body_util::Full::new(Bytes::from("Bad backend URL")).boxed();
760            return Ok(Response::builder().status(502).body(body).unwrap());
761        }
762    };
763
764    let path = req.uri().path().to_string();
765    let query = req
766        .uri()
767        .query()
768        .map(|q| format!("?{}", q))
769        .unwrap_or_default();
770
771    let ws_key = req
772        .headers()
773        .get("sec-websocket-key")
774        .and_then(|v| v.to_str().ok())
775        .unwrap_or("")
776        .to_string();
777    let ws_version = req
778        .headers()
779        .get("sec-websocket-version")
780        .and_then(|v| v.to_str().ok())
781        .unwrap_or("13")
782        .to_string();
783    let ws_protocol = req
784        .headers()
785        .get("sec-websocket-protocol")
786        .and_then(|v| v.to_str().ok())
787        .map(|s| s.to_string());
788    let host_header = req
789        .headers()
790        .get("host")
791        .and_then(|v| v.to_str().ok())
792        .unwrap_or(&backend_addr)
793        .to_string();
794
795    tracing::info!(
796        "WebSocket upgrade request to {}{}{}",
797        backend_addr,
798        path,
799        query
800    );
801
802    // Connect to the backend
803    let backend = match TcpStream::connect(&backend_addr).await {
804        Ok(s) => s,
805        Err(e) => {
806            tracing::error!("Failed to connect to backend for WebSocket: {}", e);
807            metrics.inc_errors();
808            let body = http_body_util::Full::new(Bytes::from("Backend not reachable")).boxed();
809            return Ok(Response::builder().status(502).body(body).unwrap());
810        }
811    };
812
813    // Send the upgrade request to the backend
814    let mut handshake = format!(
815        "GET {}{} HTTP/1.1\r\n\
816         Host: {}\r\n\
817         Upgrade: websocket\r\n\
818         Connection: Upgrade\r\n\
819         Sec-WebSocket-Key: {}\r\n\
820         Sec-WebSocket-Version: {}\r\n",
821        path, query, host_header, ws_key, ws_version,
822    );
823    if let Some(proto) = &ws_protocol {
824        handshake.push_str(&format!("Sec-WebSocket-Protocol: {}\r\n", proto));
825    }
826    handshake.push_str("\r\n");
827
828    let (mut backend_read, mut backend_write) = backend.into_split();
829    if let Err(e) = backend_write.write_all(handshake.as_bytes()).await {
830        tracing::error!("Failed to send WebSocket handshake to backend: {}", e);
831        metrics.inc_errors();
832        let body =
833            http_body_util::Full::new(Bytes::from("Failed to initiate WebSocket with backend"))
834                .boxed();
835        return Ok(Response::builder().status(502).body(body).unwrap());
836    }
837
838    // Read the backend's 101 response
839    let mut response_buf = vec![0u8; 4096];
840    let n = match tokio::io::AsyncReadExt::read(&mut backend_read, &mut response_buf).await {
841        Ok(n) if n > 0 => n,
842        _ => {
843            tracing::error!("No response from backend for WebSocket upgrade");
844            metrics.inc_errors();
845            let body = http_body_util::Full::new(Bytes::from(
846                "Backend did not respond to WebSocket upgrade",
847            ))
848            .boxed();
849            return Ok(Response::builder().status(502).body(body).unwrap());
850        }
851    };
852
853    let response_str = String::from_utf8_lossy(&response_buf[..n]);
854    if !response_str.contains("101") {
855        tracing::error!(
856            "Backend rejected WebSocket upgrade: {}",
857            response_str.lines().next().unwrap_or("")
858        );
859        metrics.inc_errors();
860        let body =
861            http_body_util::Full::new(Bytes::from("Backend rejected WebSocket upgrade")).boxed();
862        return Ok(Response::builder().status(502).body(body).unwrap());
863    }
864
865    // Extract headers from backend 101 response
866    let mut accept_key = String::new();
867    let mut resp_protocol = None;
868    for line in response_str.lines().skip(1) {
869        if line.trim().is_empty() {
870            break;
871        }
872        if let Some((name, value)) = line.split_once(':') {
873            let name_lower = name.trim().to_lowercase();
874            let value = value.trim().to_string();
875            if name_lower == "sec-websocket-accept" {
876                accept_key = value;
877            } else if name_lower == "sec-websocket-protocol" {
878                resp_protocol = Some(value);
879            }
880        }
881    }
882
883    // Use hyper::upgrade::on to get the client-side stream after we return 101
884    let client_upgrade = hyper::upgrade::on(req);
885
886    // Reunite the backend halves
887    let backend_stream = backend_read.reunite(backend_write).unwrap();
888
889    // Spawn the bidirectional copy task
890    tokio::spawn(async move {
891        match client_upgrade.await {
892            Ok(upgraded) => {
893                let mut client_stream = TokioIo::new(upgraded);
894                let (mut br, mut bw) = tokio::io::split(backend_stream);
895                let (mut cr, mut cw) = tokio::io::split(&mut client_stream);
896                let _ = tokio::join!(
897                    tokio::io::copy(&mut br, &mut cw),
898                    tokio::io::copy(&mut cr, &mut bw),
899                );
900            }
901            Err(e) => {
902                tracing::error!("WebSocket client upgrade failed: {}", e);
903            }
904        }
905    });
906
907    // Return 101 Switching Protocols to the client
908    let mut resp = Response::builder()
909        .status(101)
910        .header("Upgrade", "websocket")
911        .header("Connection", "Upgrade")
912        .header("Sec-WebSocket-Accept", accept_key);
913    if let Some(proto) = resp_protocol {
914        resp = resp.header("Sec-WebSocket-Protocol", proto);
915    }
916    Ok(resp
917        .body(http_body_util::Full::new(Bytes::new()).boxed())
918        .unwrap())
919}
920
921/// Returns (Response, target_url_for_logging, route_scripts)
922async fn handle_regular_request(
923    req: Request<Incoming>,
924    client: ClientType,
925    config: &crate::config::Config,
926    lua_engine: &OptionalLuaEngine,
927    circuit_breaker: &SharedCircuitBreaker,
928    _app_manager: Option<Arc<AppManager>>,
929) -> Result<(Response<BoxBody>, String, Vec<String>), hyper::Error> {
930    let route = find_matching_rule(&req, &config.rules);
931
932    match route {
933        #[allow(unused_mut, unused_variables)]
934        Some(matched_route) => {
935            let path = req.uri().path().to_string();
936            let from_domain_rule = matched_route.from_domain_rule;
937            let matched_prefix = matched_route.matched_prefix();
938
939            if !matched_route.auth.is_empty() && !verify_basic_auth(&req, &matched_route.auth) {
940                tracing::debug!("Basic auth failed for {}", req.uri().path());
941                return Ok((create_auth_required_response(), String::new(), vec![]));
942            }
943            let route_scripts = matched_route.route_scripts.clone();
944
945            // Select an available target via circuit breaker
946            let target_selection = select_target(&matched_route, &path, circuit_breaker);
947            let (mut target_url, base_url) = match target_selection {
948                Some((url, base)) => (url, base),
949                None => {
950                    // All targets are circuit-broken
951                    let body =
952                        http_body_util::Full::new(Bytes::from("Service Unavailable")).boxed();
953                    return Ok((
954                        Response::builder()
955                            .status(503)
956                            .body(body)
957                            .expect("Failed to build response"),
958                        String::new(),
959                        route_scripts,
960                    ));
961                }
962            };
963            // --- Lua route-specific on_request hooks ---
964            #[cfg(feature = "scripting")]
965            if let Some(ref engine) = lua_engine {
966                for script_name in &route_scripts {
967                    let mut lua_req = build_lua_request(&req);
968                    match engine.call_route_on_request(script_name, &mut lua_req) {
969                        RequestHookResult::Deny { status, body } => {
970                            let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
971                            return Ok((
972                                Response::builder().status(status).body(resp_body).unwrap(),
973                                target_url,
974                                route_scripts.clone(),
975                            ));
976                        }
977                        RequestHookResult::Continue(_) => {}
978                    }
979                }
980            }
981
982            // --- Lua on_route hook (global) ---
983            #[cfg(feature = "scripting")]
984            if let Some(ref engine) = lua_engine {
985                if engine.has_on_route() {
986                    let lua_req = build_lua_request(&req);
987                    match engine.call_on_route(&lua_req, &target_url) {
988                        RouteHookResult::Override(new_url) => {
989                            target_url = new_url;
990                        }
991                        RouteHookResult::Default => {}
992                    }
993                }
994                // Route-specific on_route hooks
995                for script_name in &route_scripts {
996                    let lua_req = build_lua_request(&req);
997                    match engine.call_route_on_route(script_name, &lua_req, &target_url) {
998                        RouteHookResult::Override(new_url) => {
999                            target_url = new_url;
1000                        }
1001                        RouteHookResult::Default => {}
1002                    }
1003                }
1004            }
1005
1006            // Only extract host_header when needed (domain rules only)
1007            let host_header = if from_domain_rule {
1008                req.uri()
1009                    .host()
1010                    .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1011                    .map(|s| s.to_string())
1012            } else {
1013                None
1014            };
1015
1016            let (mut parts, body) = req.into_parts();
1017
1018            // Move headers directly instead of cloning one by one
1019            let uri: hyper::Uri = target_url.parse().expect("valid URI");
1020            parts.uri = uri;
1021            parts.version = http::Version::HTTP_11;
1022            parts.extensions = http::Extensions::new();
1023
1024            let mut request = Request::from_parts(parts, body);
1025
1026            request
1027                .headers_mut()
1028                .insert("X-Forwarded-For", X_FORWARDED_FOR_VALUE.clone());
1029
1030            if from_domain_rule {
1031                if let Some(host) = host_header {
1032                    request
1033                        .headers_mut()
1034                        .insert("X-Forwarded-Host", host.parse().unwrap());
1035                }
1036            }
1037
1038            match client.request(request).await {
1039                Ok(response) => {
1040                    // --- Circuit breaker: record success or failure ---
1041                    let status_code = response.status().as_u16();
1042                    if circuit_breaker.is_failure_status(status_code) {
1043                        circuit_breaker.record_failure(&base_url);
1044                    } else {
1045                        circuit_breaker.record_success(&base_url);
1046                    }
1047
1048                    // --- Lua on_response hooks (global + route) ---
1049                    #[cfg(feature = "scripting")]
1050                    if let Some(ref engine) = lua_engine {
1051                        let has_global = engine.has_on_response();
1052                        let has_route = !route_scripts.is_empty();
1053
1054                        if has_global || has_route {
1055                            use crate::scripting::ResponseMod;
1056
1057                            let lua_req = LuaRequest {
1058                                method: String::new(),
1059                                path: String::new(),
1060                                headers: std::collections::HashMap::new(),
1061                                host: String::new(),
1062                                content_length: 0,
1063                            };
1064                            let resp_headers = extract_response_headers(response.headers());
1065                            let resp_status = response.status().as_u16();
1066
1067                            // Collect all mods: global first, then route scripts
1068                            let mut all_mods: Vec<ResponseMod> = Vec::new();
1069                            if has_global {
1070                                all_mods.push(engine.call_on_response(
1071                                    &lua_req,
1072                                    resp_status,
1073                                    &resp_headers,
1074                                ));
1075                            }
1076                            for script_name in &route_scripts {
1077                                all_mods.push(engine.call_route_on_response(
1078                                    script_name,
1079                                    &lua_req,
1080                                    resp_status,
1081                                    &resp_headers,
1082                                ));
1083                            }
1084
1085                            // Merge all mods
1086                            let mut merged = ResponseMod::default();
1087                            for mods in all_mods {
1088                                merged.set_headers.extend(mods.set_headers);
1089                                merged.remove_headers.extend(mods.remove_headers);
1090                                if mods.replace_body.is_some() {
1091                                    merged.replace_body = mods.replace_body;
1092                                }
1093                                if mods.override_status.is_some() {
1094                                    merged.override_status = mods.override_status;
1095                                }
1096                            }
1097
1098                            // Apply modifications if any
1099                            if !merged.set_headers.is_empty()
1100                                || !merged.remove_headers.is_empty()
1101                                || merged.replace_body.is_some()
1102                                || merged.override_status.is_some()
1103                            {
1104                                let (mut parts, body) = response.into_parts();
1105
1106                                if let Some(status) = merged.override_status {
1107                                    parts.status =
1108                                        hyper::StatusCode::from_u16(status).unwrap_or(parts.status);
1109                                }
1110
1111                                for name in &merged.remove_headers {
1112                                    if let Ok(header_name) =
1113                                        name.parse::<hyper::header::HeaderName>()
1114                                    {
1115                                        parts.headers.remove(header_name);
1116                                    }
1117                                }
1118
1119                                for (name, value) in &merged.set_headers {
1120                                    if let (Ok(header_name), Ok(header_value)) = (
1121                                        name.parse::<hyper::header::HeaderName>(),
1122                                        value.parse::<HeaderValue>(),
1123                                    ) {
1124                                        parts.headers.insert(header_name, header_value);
1125                                    }
1126                                }
1127
1128                                if let Some(new_body) = merged.replace_body {
1129                                    let new_bytes = Bytes::from(new_body);
1130                                    parts.headers.remove("content-length");
1131                                    parts.headers.insert(
1132                                        "content-length",
1133                                        new_bytes.len().to_string().parse().unwrap(),
1134                                    );
1135                                    let boxed = http_body_util::Full::new(new_bytes).boxed();
1136                                    return Ok((
1137                                        Response::from_parts(parts, boxed),
1138                                        target_url,
1139                                        route_scripts.clone(),
1140                                    ));
1141                                }
1142
1143                                let boxed = body.map_err(|_| unreachable!()).boxed();
1144                                return Ok((
1145                                    Response::from_parts(parts, boxed),
1146                                    target_url,
1147                                    route_scripts.clone(),
1148                                ));
1149                            }
1150                        }
1151                    }
1152
1153                    let is_html = response
1154                        .headers()
1155                        .get("content-type")
1156                        .and_then(|v| v.to_str().ok())
1157                        .map(|ct| ct.starts_with("text/html"))
1158                        .unwrap_or(false);
1159
1160                    if is_html {
1161                        if let Some(prefix) = matched_prefix {
1162                            let (parts, body) = response.into_parts();
1163                            let body_bytes = body
1164                                .collect()
1165                                .await
1166                                .map(|collected| collected.to_bytes())
1167                                .unwrap_or_default();
1168
1169                            // Decompress gzip/deflate body before rewriting
1170                            let is_gzip = parts
1171                                .headers
1172                                .get("content-encoding")
1173                                .and_then(|v| v.to_str().ok())
1174                                .map(|v| v.contains("gzip"))
1175                                .unwrap_or(false);
1176                            let is_deflate = parts
1177                                .headers
1178                                .get("content-encoding")
1179                                .and_then(|v| v.to_str().ok())
1180                                .map(|v| v.contains("deflate"))
1181                                .unwrap_or(false);
1182
1183                            let raw_bytes = if is_gzip {
1184                                use std::io::Read;
1185                                let mut decoder = flate2::read::GzDecoder::new(&body_bytes[..]);
1186                                let mut decoded = Vec::new();
1187                                decoder.read_to_end(&mut decoded).unwrap_or_default();
1188                                Bytes::from(decoded)
1189                            } else if is_deflate {
1190                                use std::io::Read;
1191                                let mut decoder =
1192                                    flate2::read::DeflateDecoder::new(&body_bytes[..]);
1193                                let mut decoded = Vec::new();
1194                                decoder.read_to_end(&mut decoded).unwrap_or_default();
1195                                Bytes::from(decoded)
1196                            } else {
1197                                body_bytes
1198                            };
1199
1200                            let html = String::from_utf8_lossy(&raw_bytes);
1201                            let rewritten = html
1202                                .replace("href=\"/", &format!("href=\"{}/", prefix))
1203                                .replace("src=\"/", &format!("src=\"{}/", prefix))
1204                                .replace("action=\"/", &format!("action=\"{}/", prefix));
1205                            let rewritten_bytes = Bytes::from(rewritten);
1206                            let mut parts = parts;
1207                            parts.headers.remove("content-encoding");
1208                            parts.headers.remove("content-length");
1209                            parts.headers.insert(
1210                                "content-length",
1211                                rewritten_bytes.len().to_string().parse().unwrap(),
1212                            );
1213                            let boxed = http_body_util::Full::new(rewritten_bytes).boxed();
1214                            return Ok((
1215                                Response::from_parts(parts, boxed),
1216                                target_url,
1217                                route_scripts.clone(),
1218                            ));
1219                        }
1220                    }
1221
1222                    let (parts, body) = response.into_parts();
1223                    let boxed = body.map_err(|_| unreachable!()).boxed();
1224                    Ok((
1225                        Response::from_parts(parts, boxed),
1226                        target_url,
1227                        route_scripts,
1228                    ))
1229                }
1230                Err(e) => {
1231                    circuit_breaker.record_failure(&base_url);
1232                    tracing::error!("Backend request failed: {} (target: {})", e, target_url);
1233                    let body = http_body_util::Full::new(Bytes::from("Bad Gateway")).boxed();
1234                    Ok((
1235                        Response::builder()
1236                            .status(502)
1237                            .body(body)
1238                            .expect("Failed to build response"),
1239                        target_url,
1240                        route_scripts,
1241                    ))
1242                }
1243            }
1244        }
1245        None => {
1246            // Suppress unused variable warning when scripting feature is disabled
1247            let _ = lua_engine;
1248            let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
1249            Ok((
1250                Response::builder()
1251                    .status(421)
1252                    .body(body)
1253                    .expect("Failed to build response"),
1254                String::new(),
1255                vec![],
1256            ))
1257        }
1258    }
1259}
1260
/// How the target URL is resolved from the matched route.
#[derive(Debug, Clone, PartialEq, Eq)]
enum UrlResolution {
    /// Domain, Default: append the full request path to the target URL.
    AppendPath,
    /// DomainPath, Prefix: strip the matched prefix and append the remaining suffix.
    StripPrefix(String),
    /// Exact, Regex: use the target URL as-is.
    Identity,
}
1270
1271/// A matched routing rule with all the info needed to resolve a target URL
1272struct MatchedRoute<'a> {
1273    targets: &'a [crate::config::Target],
1274    from_domain_rule: bool,
1275    resolution: UrlResolution,
1276    route_scripts: Vec<String>,
1277    auth: Vec<crate::auth::BasicAuth>,
1278}
1279
1280impl<'a> MatchedRoute<'a> {
1281    fn matched_prefix(&self) -> Option<String> {
1282        match &self.resolution {
1283            UrlResolution::StripPrefix(prefix) => Some(prefix.trim_end_matches('/').to_string()),
1284            _ => None,
1285        }
1286    }
1287}
1288
1289/// Resolve a target URL based on the resolution strategy
1290fn resolve_target_url(
1291    target: &crate::config::Target,
1292    path: &str,
1293    resolution: &UrlResolution,
1294) -> String {
1295    let target_str = target.url.as_str();
1296    match resolution {
1297        UrlResolution::AppendPath => {
1298            if target_str.ends_with('/') {
1299                format!("{}{}", target_str, &path[1..])
1300            } else {
1301                format!("{}{}", target_str, path)
1302            }
1303        }
1304        UrlResolution::StripPrefix(prefix) => {
1305            let suffix = if path.len() >= prefix.len() {
1306                &path[prefix.len()..]
1307            } else {
1308                ""
1309            };
1310            format!("{}{}", target_str, suffix)
1311        }
1312        UrlResolution::Identity => target_str.to_owned(),
1313    }
1314}
1315
1316/// Pure routing: find which rule matches the request
1317fn find_matching_rule<'a>(
1318    req: &Request<Incoming>,
1319    rules: &'a [crate::config::ProxyRule],
1320) -> Option<MatchedRoute<'a>> {
1321    let host = req
1322        .uri()
1323        .host()
1324        .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1325        .map(|h| h.split(':').next().unwrap_or(h))?;
1326
1327    let path = req.uri().path();
1328
1329    for rule in rules {
1330        match &rule.matcher {
1331            crate::config::RuleMatcher::Domain(domain) => {
1332                if domain == host && !rule.targets.is_empty() {
1333                    return Some(MatchedRoute {
1334                        targets: &rule.targets,
1335                        from_domain_rule: true,
1336                        resolution: UrlResolution::AppendPath,
1337                        route_scripts: rule.scripts.clone(),
1338                        auth: rule.auth.clone(),
1339                    });
1340                }
1341            }
1342            crate::config::RuleMatcher::DomainPath(domain, path_prefix) => {
1343                if domain == host && !rule.targets.is_empty() {
1344                    let matches = path.starts_with(path_prefix)
1345                        || (path_prefix.ends_with('/')
1346                            && path == path_prefix.trim_end_matches('/'));
1347                    if matches {
1348                        return Some(MatchedRoute {
1349                            targets: &rule.targets,
1350                            from_domain_rule: true,
1351                            resolution: UrlResolution::StripPrefix(path_prefix.clone()),
1352                            route_scripts: rule.scripts.clone(),
1353                            auth: rule.auth.clone(),
1354                        });
1355                    }
1356                }
1357            }
1358            _ => {}
1359        }
1360    }
1361
1362    // Check specific rules (Exact, Prefix, Regex) before Default
1363    for rule in rules {
1364        match &rule.matcher {
1365            crate::config::RuleMatcher::Exact(exact) => {
1366                if path == exact && !rule.targets.is_empty() {
1367                    return Some(MatchedRoute {
1368                        targets: &rule.targets,
1369                        from_domain_rule: false,
1370                        resolution: UrlResolution::Identity,
1371                        route_scripts: rule.scripts.clone(),
1372                        auth: rule.auth.clone(),
1373                    });
1374                }
1375            }
1376            crate::config::RuleMatcher::Prefix(prefix) => {
1377                if !rule.targets.is_empty() {
1378                    // Match /db against prefix /db/ (path without trailing slash)
1379                    let matches = path.starts_with(prefix)
1380                        || (prefix.ends_with('/') && path == prefix.trim_end_matches('/'));
1381                    if matches {
1382                        return Some(MatchedRoute {
1383                            targets: &rule.targets,
1384                            from_domain_rule: false,
1385                            resolution: UrlResolution::StripPrefix(prefix.clone()),
1386                            route_scripts: rule.scripts.clone(),
1387                            auth: rule.auth.clone(),
1388                        });
1389                    }
1390                }
1391            }
1392            crate::config::RuleMatcher::Regex(ref rm) => {
1393                if rm.is_match(path) && !rule.targets.is_empty() {
1394                    return Some(MatchedRoute {
1395                        targets: &rule.targets,
1396                        from_domain_rule: false,
1397                        resolution: UrlResolution::Identity,
1398                        route_scripts: rule.scripts.clone(),
1399                        auth: rule.auth.clone(),
1400                    });
1401                }
1402            }
1403            _ => {}
1404        }
1405    }
1406
1407    // Fall back to Default rule
1408    for rule in rules {
1409        if let crate::config::RuleMatcher::Default = &rule.matcher {
1410            if !rule.targets.is_empty() {
1411                return Some(MatchedRoute {
1412                    targets: &rule.targets,
1413                    from_domain_rule: false,
1414                    resolution: UrlResolution::AppendPath,
1415                    route_scripts: rule.scripts.clone(),
1416                    auth: rule.auth.clone(),
1417                });
1418            }
1419        }
1420    }
1421
1422    None
1423}
1424
1425/// Select the first available target (circuit breaker aware).
1426/// Returns (resolved_url, base_url) for logging and circuit breaker tracking.
1427fn select_target(
1428    route: &MatchedRoute<'_>,
1429    path: &str,
1430    circuit_breaker: &crate::circuit_breaker::CircuitBreaker,
1431) -> Option<(String, String)> {
1432    for target in route.targets {
1433        let base_url = target.url.as_str().to_owned();
1434        if circuit_breaker.is_available(&base_url) {
1435            let resolved = resolve_target_url(target, path, &route.resolution);
1436            return Some((resolved, base_url));
1437        }
1438    }
1439    None
1440}
1441
1442/// Backward-compatible wrapper: returns (target_url, from_domain_rule, matched_prefix, route_scripts)
1443fn find_target(
1444    req: &Request<Incoming>,
1445    rules: &[crate::config::ProxyRule],
1446) -> Option<(String, bool, Option<String>, Vec<String>)> {
1447    let route = find_matching_rule(req, rules)?;
1448    let path = req.uri().path();
1449    let target = route.targets.first()?;
1450    let resolved = resolve_target_url(target, path, &route.resolution);
1451    let matched_prefix = route.matched_prefix();
1452    Some((
1453        resolved,
1454        route.from_domain_rule,
1455        matched_prefix,
1456        route.route_scripts,
1457    ))
1458}