Skip to main content

soli_proxy/server/
mod.rs

1// When scripting feature is disabled, OptionalLuaEngine = () and cloning it triggers warnings
2#![allow(clippy::let_unit_value, clippy::clone_on_copy, clippy::unit_arg)]
3
4use crate::acme::ChallengeStore;
5use crate::app::AppManager;
6use crate::circuit_breaker::SharedCircuitBreaker;
7use crate::config::ConfigManager;
8use crate::metrics::SharedMetrics;
9use crate::shutdown::ShutdownCoordinator;
10use anyhow::Result;
11use bytes::Bytes;
12use http_body_util::BodyExt;
13use hyper::body::Incoming;
14use hyper::header::HeaderValue;
15use hyper::service::service_fn;
16use hyper::Request;
17use hyper::Response;
18use hyper_util::client::legacy::connect::HttpConnector;
19use hyper_util::client::legacy::Client;
20use hyper_util::rt::TokioExecutor;
21use hyper_util::rt::TokioIo;
22use socket2::{Domain, Protocol, Socket, Type};
23use std::net::SocketAddr;
24use std::sync::Arc;
25use std::time::Duration;
26use tokio::io::AsyncWriteExt;
27use tokio::net::{TcpListener, TcpStream};
28use tokio_rustls::TlsAcceptor;
29
30#[cfg(feature = "scripting")]
31use crate::scripting::{LuaEngine, LuaRequest, RequestHookResult, RouteHookResult};
32
33type ClientType = Client<HttpConnector, Incoming>;
34type BoxBody = http_body_util::combinators::BoxBody<Bytes, std::convert::Infallible>;
35
36#[cfg(feature = "scripting")]
37type OptionalLuaEngine = Option<LuaEngine>;
38#[cfg(not(feature = "scripting"))]
39type OptionalLuaEngine = ();
40
41/// Helper to record app-specific metrics
42fn record_app_metrics(
43    metrics: &SharedMetrics,
44    app_manager: &Option<Arc<AppManager>>,
45    target_url: &str,
46    bytes_in: u64,
47    bytes_out: u64,
48    status: u16,
49    duration: Duration,
50) {
51    if let Some(ref manager) = app_manager {
52        if let Ok(url) = url::Url::parse(target_url) {
53            if let Some(port) = url.port() {
54                if let Some(app_name) = futures::executor::block_on(manager.get_app_name(port)) {
55                    metrics.record_app_request(&app_name, bytes_in, bytes_out, status, duration);
56                }
57            }
58        }
59    }
60}
61
62/// Pre-parsed header value for X-Forwarded-For to avoid parsing on every request
63static X_FORWARDED_FOR_VALUE: std::sync::LazyLock<HeaderValue> =
64    std::sync::LazyLock::new(|| HeaderValue::from_static("127.0.0.1"));
65
66fn create_listener(addr: SocketAddr) -> Result<TcpListener> {
67    let domain = if addr.is_ipv4() {
68        Domain::IPV4
69    } else {
70        Domain::IPV6
71    };
72    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
73    socket.set_reuse_address(true)?;
74    socket.set_reuse_port(true)?;
75    socket.set_nonblocking(true)?;
76    socket.bind(&addr.into())?;
77    socket.listen(8192)?;
78    let std_listener: std::net::TcpListener = socket.into();
79    Ok(TcpListener::from_std(std_listener)?)
80}
81
82fn create_client() -> ClientType {
83    let exec = TokioExecutor::new();
84    let mut connector = HttpConnector::new();
85    connector.set_nodelay(true);
86    connector.set_keepalive(Some(Duration::from_secs(30)));
87    connector.set_connect_timeout(Some(Duration::from_secs(5)));
88    Client::builder(exec)
89        .pool_max_idle_per_host(256)
90        .pool_idle_timeout(Duration::from_secs(60))
91        .build(connector)
92}
93
94pub struct ProxyServer {
95    config: Arc<ConfigManager>,
96    shutdown: ShutdownCoordinator,
97    tls_acceptor: Option<TlsAcceptor>,
98    https_addr: Option<SocketAddr>,
99    metrics: SharedMetrics,
100    challenge_store: ChallengeStore,
101    lua_engine: OptionalLuaEngine,
102    circuit_breaker: SharedCircuitBreaker,
103    app_manager: Option<Arc<AppManager>>,
104}
105
106impl ProxyServer {
107    pub fn new(
108        config: Arc<ConfigManager>,
109        shutdown: ShutdownCoordinator,
110        metrics: SharedMetrics,
111        challenge_store: ChallengeStore,
112        lua_engine: OptionalLuaEngine,
113        circuit_breaker: SharedCircuitBreaker,
114        app_manager: Option<Arc<AppManager>>,
115    ) -> Result<Self> {
116        Ok(Self {
117            config,
118            shutdown,
119            tls_acceptor: None,
120            https_addr: None,
121            metrics,
122            challenge_store,
123            lua_engine,
124            circuit_breaker,
125            app_manager,
126        })
127    }
128
129    #[allow(clippy::too_many_arguments)]
130    pub fn with_https(
131        config: Arc<ConfigManager>,
132        shutdown: ShutdownCoordinator,
133        tls_acceptor: TlsAcceptor,
134        https_addr: SocketAddr,
135        metrics: SharedMetrics,
136        challenge_store: ChallengeStore,
137        lua_engine: OptionalLuaEngine,
138        circuit_breaker: SharedCircuitBreaker,
139        app_manager: Option<Arc<AppManager>>,
140    ) -> Result<Self> {
141        Ok(Self {
142            config,
143            shutdown,
144            tls_acceptor: Some(tls_acceptor),
145            https_addr: Some(https_addr),
146            metrics,
147            challenge_store,
148            lua_engine,
149            circuit_breaker,
150            app_manager,
151        })
152    }
153
154    pub async fn run(&self) -> Result<()> {
155        let cfg = self.config.get_config();
156        let http_addr: SocketAddr = cfg.server.bind.parse()?;
157        let https_addr = self.https_addr;
158
159        let has_https = https_addr.is_some();
160        let num_listeners = std::thread::available_parallelism()
161            .map(|n| n.get())
162            .unwrap_or(4);
163
164        // Spawn N HTTP accept loops with SO_REUSEPORT
165        // Each listener gets its own client with its own connection pool to avoid contention
166        let app_manager = self.app_manager.clone();
167        for i in 0..num_listeners {
168            let config_clone = self.config.clone();
169            let shutdown_clone = self.shutdown.clone();
170            let metrics_clone = self.metrics.clone();
171            let challenge_store_clone = self.challenge_store.clone();
172            let lua_clone = self.lua_engine.clone();
173            let cb_clone = self.circuit_breaker.clone();
174            let am_clone = app_manager.clone();
175
176            tokio::spawn(async move {
177                if let Err(e) = run_http_server(
178                    http_addr,
179                    config_clone,
180                    shutdown_clone,
181                    metrics_clone,
182                    challenge_store_clone,
183                    lua_clone,
184                    cb_clone,
185                    am_clone,
186                )
187                .await
188                {
189                    tracing::error!("HTTP/1.1 server error (listener {}): {}", i, e);
190                }
191            });
192        }
193
194        if let Some(https_addr) = https_addr {
195            for i in 0..num_listeners {
196                let config_clone = self.config.clone();
197                let shutdown_clone = self.shutdown.clone();
198                let acceptor = self.tls_acceptor.as_ref().unwrap().clone();
199                let metrics_clone = self.metrics.clone();
200                let challenge_store_clone = self.challenge_store.clone();
201                let lua_clone = self.lua_engine.clone();
202                let cb_clone = self.circuit_breaker.clone();
203                let am_clone = app_manager.clone();
204
205                tokio::spawn(async move {
206                    if let Err(e) = run_https_server(
207                        https_addr,
208                        config_clone,
209                        shutdown_clone,
210                        acceptor,
211                        metrics_clone,
212                        challenge_store_clone,
213                        lua_clone,
214                        cb_clone,
215                        am_clone,
216                    )
217                    .await
218                    {
219                        tracing::error!("HTTPS/2 server error (listener {}): {}", i, e);
220                    }
221                });
222            }
223        }
224
225        tracing::info!(
226            "HTTP/1.1 server listening on {} ({} accept loops)",
227            http_addr,
228            num_listeners
229        );
230        if has_https {
231            tracing::info!(
232                "HTTPS/2 server listening on {} ({} accept loops)",
233                https_addr.unwrap(),
234                num_listeners
235            );
236        }
237
238        loop {
239            if self.shutdown.is_shutting_down() {
240                tracing::info!("Shutting down servers...");
241                break;
242            }
243            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
244        }
245
246        Ok(())
247    }
248}
249
250#[allow(clippy::too_many_arguments)]
251async fn run_http_server(
252    addr: SocketAddr,
253    config: Arc<ConfigManager>,
254    shutdown: ShutdownCoordinator,
255    metrics: SharedMetrics,
256    challenge_store: ChallengeStore,
257    lua_engine: OptionalLuaEngine,
258    circuit_breaker: SharedCircuitBreaker,
259    app_manager: Option<Arc<AppManager>>,
260) -> Result<()> {
261    let listener = create_listener(addr)?;
262    let client = create_client();
263
264    loop {
265        if shutdown.is_shutting_down() {
266            break;
267        }
268
269        match listener.accept().await {
270            Ok((stream, _)) => {
271                let _ = stream.set_nodelay(true);
272                let client = client.clone();
273                let config = config.clone();
274                let metrics = metrics.clone();
275                let cs = challenge_store.clone();
276                let lua = lua_engine.clone();
277                let cb = circuit_breaker.clone();
278                let am = app_manager.clone();
279                tokio::spawn(async move {
280                    if let Err(e) =
281                        handle_http11_connection(stream, client, config, metrics, cs, lua, cb, am)
282                            .await
283                    {
284                        tracing::debug!("HTTP/1.1 connection error: {}", e);
285                    }
286                });
287            }
288            Err(e) => {
289                tracing::error!("HTTP/1.1 accept error: {}", e);
290            }
291        }
292    }
293
294    Ok(())
295}
296
297#[allow(clippy::too_many_arguments)]
298async fn run_https_server(
299    addr: SocketAddr,
300    config: Arc<ConfigManager>,
301    shutdown: ShutdownCoordinator,
302    acceptor: TlsAcceptor,
303    metrics: SharedMetrics,
304    challenge_store: ChallengeStore,
305    lua_engine: OptionalLuaEngine,
306    circuit_breaker: SharedCircuitBreaker,
307    app_manager: Option<Arc<AppManager>>,
308) -> Result<()> {
309    let listener = create_listener(addr)?;
310    let client = create_client();
311
312    loop {
313        if shutdown.is_shutting_down() {
314            break;
315        }
316
317        match listener.accept().await {
318            Ok((stream, _)) => {
319                let _ = stream.set_nodelay(true);
320                let client = client.clone();
321                let config = config.clone();
322                let acceptor = acceptor.clone();
323                let metrics = metrics.clone();
324                let cs = challenge_store.clone();
325                let lua = lua_engine.clone();
326                let cb = circuit_breaker.clone();
327                let am = app_manager.clone();
328                tokio::spawn(async move {
329                    match acceptor.accept(stream).await {
330                        Ok(tls_stream) => {
331                            metrics.inc_tls_connections();
332                            if let Err(e) = handle_https2_connection(
333                                tls_stream, client, config, metrics, cs, lua, cb, am,
334                            )
335                            .await
336                            {
337                                tracing::debug!("HTTPS/2 connection error: {}", e);
338                            }
339                        }
340                        Err(e) => {
341                            tracing::error!("TLS accept error: {}", e);
342                        }
343                    }
344                });
345            }
346            Err(e) => {
347                tracing::error!("HTTPS/2 accept error: {}", e);
348            }
349        }
350    }
351
352    Ok(())
353}
354
355#[allow(clippy::too_many_arguments)]
356async fn handle_http11_connection(
357    stream: tokio::net::TcpStream,
358    client: ClientType,
359    config: Arc<ConfigManager>,
360    metrics: SharedMetrics,
361    challenge_store: ChallengeStore,
362    lua_engine: OptionalLuaEngine,
363    circuit_breaker: SharedCircuitBreaker,
364    app_manager: Option<Arc<AppManager>>,
365) -> Result<()> {
366    let io = TokioIo::new(stream);
367    let svc = service_fn(move |req| {
368        handle_request(
369            req,
370            client.clone(),
371            config.clone(),
372            metrics.clone(),
373            challenge_store.clone(),
374            lua_engine.clone(),
375            circuit_breaker.clone(),
376            app_manager.clone(),
377        )
378    });
379
380    let conn = hyper::server::conn::http1::Builder::new()
381        .keep_alive(true)
382        .pipeline_flush(true)
383        .serve_connection(io, svc)
384        .with_upgrades();
385
386    if let Err(e) = conn.await {
387        tracing::debug!("HTTP/1.1 connection error: {}", e);
388    }
389
390    Ok(())
391}
392
393#[allow(clippy::too_many_arguments)]
394async fn handle_https2_connection(
395    stream: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
396    client: ClientType,
397    config: Arc<ConfigManager>,
398    metrics: SharedMetrics,
399    challenge_store: ChallengeStore,
400    lua_engine: OptionalLuaEngine,
401    circuit_breaker: SharedCircuitBreaker,
402    app_manager: Option<Arc<AppManager>>,
403) -> Result<()> {
404    let is_h2 = stream.get_ref().1.alpn_protocol() == Some(b"h2");
405
406    let io = TokioIo::new(stream);
407
408    if is_h2 {
409        let exec = TokioExecutor::new();
410        let svc = service_fn(move |req| {
411            handle_request(
412                req,
413                client.clone(),
414                config.clone(),
415                metrics.clone(),
416                challenge_store.clone(),
417                lua_engine.clone(),
418                circuit_breaker.clone(),
419                app_manager.clone(),
420            )
421        });
422        let conn = hyper::server::conn::http2::Builder::new(exec)
423            .initial_stream_window_size(1024 * 1024)
424            .initial_connection_window_size(2 * 1024 * 1024)
425            .max_concurrent_streams(250)
426            .serve_connection(io, svc);
427        if let Err(e) = conn.await {
428            tracing::debug!("HTTPS/2 connection error: {}", e);
429        }
430    } else {
431        let svc = service_fn(move |req| {
432            handle_request(
433                req,
434                client.clone(),
435                config.clone(),
436                metrics.clone(),
437                challenge_store.clone(),
438                lua_engine.clone(),
439                circuit_breaker.clone(),
440                app_manager.clone(),
441            )
442        });
443        let conn = hyper::server::conn::http1::Builder::new()
444            .keep_alive(true)
445            .pipeline_flush(true)
446            .serve_connection(io, svc)
447            .with_upgrades();
448        if let Err(e) = conn.await {
449            tracing::debug!("HTTPS/1.1 connection error: {}", e);
450        }
451    }
452
453    Ok(())
454}
455
/// Extract headers from a hyper request into a name → value map for Lua.
///
/// Names are lowercase (hyper normalises `HeaderName`s). When a header occurs
/// more than once, all values are kept and joined in arrival order: with
/// `"; "` for `cookie`, whose fields HTTP/2 clients may split up
/// (RFC 9113 §8.2.3), and with `", "` for everything else (RFC 9110 §5.3).
/// Values that are not valid UTF-8 are decoded lossily rather than dropped.
#[cfg(feature = "scripting")]
fn extract_headers(req: &Request<Incoming>) -> std::collections::HashMap<String, String> {
    use std::collections::hash_map::Entry;
    use std::collections::HashMap;

    let mut headers: HashMap<String, String> = HashMap::with_capacity(req.headers().keys_len());
    // `HeaderMap` iteration yields one (name, value) pair per value, duplicates included.
    for (name, value) in req.headers() {
        let value = String::from_utf8_lossy(value.as_bytes()).into_owned();
        match headers.entry(name.as_str().to_owned()) {
            Entry::Occupied(mut slot) => {
                let separator = if *name == hyper::header::COOKIE { "; " } else { ", " };
                let joined = slot.get_mut();
                joined.push_str(separator);
                joined.push_str(&value);
            }
            Entry::Vacant(slot) => {
                slot.insert(value);
            }
        }
    }
    headers
}
469
/// Build a `LuaRequest` snapshot (method, path, headers, host, content length)
/// from a hyper request.
///
/// `host` comes from the request URI when it carries an authority. Otherwise it
/// comes from the `Host` header, which is parsed so that any `:port` suffix is
/// dropped, the same way `Uri::host` drops it. That way scripts see one host
/// spelling regardless of which source supplied it. If the header cannot be
/// parsed, its raw text is used; if it is absent, `host` is empty.
///
/// `content_length` is 0 when the header is missing or not a valid number.
#[cfg(feature = "scripting")]
fn build_lua_request(req: &Request<Incoming>) -> LuaRequest {
    let host = match req.uri().host() {
        Some(host) => host.to_owned(),
        None => req
            .headers()
            .get(hyper::header::HOST)
            .and_then(|value| value.to_str().ok())
            .map(|raw| {
                // A bare `host[:port]` parses as an authority-only URI.
                raw.parse::<hyper::Uri>()
                    .ok()
                    .and_then(|uri| uri.host().map(str::to_owned))
                    .unwrap_or_else(|| raw.to_owned())
            })
            .unwrap_or_default(),
    };

    let content_length = req
        .headers()
        .get(hyper::header::CONTENT_LENGTH)
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.parse().ok())
        .unwrap_or(0);

    LuaRequest {
        method: req.method().to_string(),
        path: req.uri().path().to_string(),
        headers: extract_headers(req),
        host,
        content_length,
    }
}
495
/// Copies upstream response headers into a name → value map for Lua hooks.
///
/// Names are lowercased. Values that are not visible ASCII become `""`. When a
/// header name repeats, only its last value survives.
/// NOTE(review): repeated headers such as `set-cookie` therefore lose values.
#[cfg(feature = "scripting")]
fn extract_response_headers(
    headers: &hyper::HeaderMap,
) -> std::collections::HashMap<String, String> {
    let mut by_name = std::collections::HashMap::new();
    for (name, value) in headers.iter() {
        let text = match value.to_str() {
            Ok(text) => text,
            Err(_) => "",
        };
        by_name.insert(name.as_str().to_lowercase(), text.to_string());
    }
    by_name
}
511
512#[allow(clippy::too_many_arguments)]
513async fn handle_request(
514    req: Request<Incoming>,
515    client: ClientType,
516    config_manager: Arc<ConfigManager>,
517    metrics: SharedMetrics,
518    challenge_store: ChallengeStore,
519    lua_engine: OptionalLuaEngine,
520    circuit_breaker: SharedCircuitBreaker,
521    app_manager: Option<Arc<AppManager>>,
522) -> Result<Response<BoxBody>, hyper::Error> {
523    let start_time = std::time::Instant::now();
524    metrics.inc_in_flight();
525    let config = config_manager.get_config();
526
527    // ACME challenge check — must come before all other routing
528    if let Some(response) = handle_acme_challenge(&req, &challenge_store) {
529        metrics.dec_in_flight();
530        return Ok(response);
531    }
532
533    if is_metrics_request(&req) {
534        let duration = start_time.elapsed();
535        metrics.dec_in_flight();
536        let metrics_output = metrics.format_metrics();
537        metrics.record_request(0, metrics_output.len() as u64, 200, duration);
538        let body = http_body_util::Full::new(Bytes::from(metrics_output)).boxed();
539        return Ok(Response::builder()
540            .status(200)
541            .header("Content-Type", "text/plain")
542            .body(body)
543            .unwrap());
544    }
545
546    // --- Lua on_request hook ---
547    #[cfg(feature = "scripting")]
548    if let Some(ref engine) = lua_engine {
549        if engine.has_on_request() {
550            let mut lua_req = build_lua_request(&req);
551            match engine.call_on_request(&mut lua_req) {
552                RequestHookResult::Deny { status, body } => {
553                    metrics.dec_in_flight();
554                    let duration = start_time.elapsed();
555                    metrics.record_request(0, body.len() as u64, status, duration);
556                    let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
557                    return Ok(Response::builder().status(status).body(resp_body).unwrap());
558                }
559                RequestHookResult::Continue(updated_req) => {
560                    // Apply any header modifications back to the hyper request
561                    // We can't easily mutate the incoming request headers here since
562                    // we'd need to own it, so we store the lua_req for later use.
563                    // Headers set via set_header in on_request will be applied after
564                    // the request is decomposed into parts.
565                    let _ = updated_req;
566                }
567            }
568        }
569    }
570
571    let is_websocket = is_websocket_request(&req);
572
573    if is_websocket {
574        return handle_websocket_request(
575            req,
576            client,
577            &config,
578            &metrics,
579            start_time,
580            app_manager.clone(),
581        )
582        .await;
583    }
584
585    let result = handle_regular_request(
586        req,
587        client,
588        &config,
589        &lua_engine,
590        &circuit_breaker,
591        app_manager.clone(),
592    )
593    .await;
594    let duration = start_time.elapsed();
595
596    metrics.dec_in_flight();
597
598    match result {
599        #[allow(unused_variables)]
600        Ok((response, _target_url, route_scripts)) => {
601            let status = response.status().as_u16();
602
603            // --- Lua on_request_end hooks (global + route) ---
604            #[cfg(feature = "scripting")]
605            if let Some(ref engine) = lua_engine {
606                let lua_req = LuaRequest {
607                    method: String::new(),
608                    path: String::new(),
609                    headers: std::collections::HashMap::new(),
610                    host: String::new(),
611                    content_length: 0,
612                };
613                let duration_ms = duration.as_secs_f64() * 1000.0;
614
615                // Global on_request_end
616                if engine.has_on_request_end() {
617                    engine.call_on_request_end(&lua_req, status, duration_ms, &_target_url);
618                }
619
620                // Route-specific on_request_end
621                for script_name in &route_scripts {
622                    engine.call_route_on_request_end(
623                        script_name,
624                        &lua_req,
625                        status,
626                        duration_ms,
627                        &_target_url,
628                    );
629                }
630            }
631
632            metrics.record_request(0, 0, status, duration);
633            record_app_metrics(&metrics, &app_manager, &_target_url, 0, 0, status, duration);
634            let (parts, body) = response.into_parts();
635            let boxed = body.map_err(|_| unreachable!()).boxed();
636            Ok(Response::from_parts(parts, boxed))
637        }
638        Err(e) => {
639            metrics.inc_errors();
640            Err(e)
641        }
642    }
643}
644
645fn is_websocket_request(req: &Request<Incoming>) -> bool {
646    if let Some(upgrade) = req.headers().get("upgrade") {
647        if upgrade == "websocket" {
648            return true;
649        }
650    }
651    false
652}
653
654fn is_metrics_request(req: &Request<Incoming>) -> bool {
655    req.uri().path() == "/metrics"
656}
657
658fn handle_acme_challenge(
659    req: &Request<Incoming>,
660    challenge_store: &ChallengeStore,
661) -> Option<Response<BoxBody>> {
662    let path = req.uri().path();
663    let prefix = "/.well-known/acme-challenge/";
664
665    if !path.starts_with(prefix) {
666        return None;
667    }
668
669    let token = &path[prefix.len()..];
670
671    if let Ok(store) = challenge_store.read() {
672        if let Some(key_auth) = store.get(token) {
673            let body = http_body_util::Full::new(Bytes::from(key_auth.clone())).boxed();
674            return Some(
675                Response::builder()
676                    .status(200)
677                    .header("Content-Type", "text/plain")
678                    .body(body)
679                    .unwrap(),
680            );
681        }
682    }
683
684    let body = http_body_util::Full::new(Bytes::from("Challenge not found")).boxed();
685    Some(Response::builder().status(404).body(body).unwrap())
686}
687
688async fn handle_websocket_request(
689    req: Request<Incoming>,
690    _client: ClientType,
691    config: &crate::config::Config,
692    metrics: &SharedMetrics,
693    _start_time: std::time::Instant,
694    _app_manager: Option<Arc<AppManager>>,
695) -> Result<Response<BoxBody>, hyper::Error> {
696    let target_result = find_target(&req, &config.rules);
697
698    if target_result.is_none() {
699        metrics.inc_errors();
700        let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
701        return Ok(Response::builder().status(421).body(body).unwrap());
702    }
703
704    let (target_url, _, _, _) = target_result.unwrap();
705
706    // Extract host:port from target URL (e.g. "http://127.0.0.1:3000/path" -> "127.0.0.1:3000")
707    let backend_addr = match url::Url::parse(&target_url) {
708        Ok(u) => format!(
709            "{}:{}",
710            u.host_str().unwrap_or("127.0.0.1"),
711            u.port().unwrap_or(80)
712        ),
713        Err(_) => {
714            metrics.inc_errors();
715            let body = http_body_util::Full::new(Bytes::from("Bad backend URL")).boxed();
716            return Ok(Response::builder().status(502).body(body).unwrap());
717        }
718    };
719
720    let path = req.uri().path().to_string();
721    let query = req
722        .uri()
723        .query()
724        .map(|q| format!("?{}", q))
725        .unwrap_or_default();
726
727    let ws_key = req
728        .headers()
729        .get("sec-websocket-key")
730        .and_then(|v| v.to_str().ok())
731        .unwrap_or("")
732        .to_string();
733    let ws_version = req
734        .headers()
735        .get("sec-websocket-version")
736        .and_then(|v| v.to_str().ok())
737        .unwrap_or("13")
738        .to_string();
739    let ws_protocol = req
740        .headers()
741        .get("sec-websocket-protocol")
742        .and_then(|v| v.to_str().ok())
743        .map(|s| s.to_string());
744    let host_header = req
745        .headers()
746        .get("host")
747        .and_then(|v| v.to_str().ok())
748        .unwrap_or(&backend_addr)
749        .to_string();
750
751    tracing::info!(
752        "WebSocket upgrade request to {}{}{}",
753        backend_addr,
754        path,
755        query
756    );
757
758    // Connect to the backend
759    let backend = match TcpStream::connect(&backend_addr).await {
760        Ok(s) => s,
761        Err(e) => {
762            tracing::error!("Failed to connect to backend for WebSocket: {}", e);
763            metrics.inc_errors();
764            let body = http_body_util::Full::new(Bytes::from("Backend not reachable")).boxed();
765            return Ok(Response::builder().status(502).body(body).unwrap());
766        }
767    };
768
769    // Send the upgrade request to the backend
770    let mut handshake = format!(
771        "GET {}{} HTTP/1.1\r\n\
772         Host: {}\r\n\
773         Upgrade: websocket\r\n\
774         Connection: Upgrade\r\n\
775         Sec-WebSocket-Key: {}\r\n\
776         Sec-WebSocket-Version: {}\r\n",
777        path, query, host_header, ws_key, ws_version,
778    );
779    if let Some(proto) = &ws_protocol {
780        handshake.push_str(&format!("Sec-WebSocket-Protocol: {}\r\n", proto));
781    }
782    handshake.push_str("\r\n");
783
784    let (mut backend_read, mut backend_write) = backend.into_split();
785    if let Err(e) = backend_write.write_all(handshake.as_bytes()).await {
786        tracing::error!("Failed to send WebSocket handshake to backend: {}", e);
787        metrics.inc_errors();
788        let body =
789            http_body_util::Full::new(Bytes::from("Failed to initiate WebSocket with backend"))
790                .boxed();
791        return Ok(Response::builder().status(502).body(body).unwrap());
792    }
793
794    // Read the backend's 101 response
795    let mut response_buf = vec![0u8; 4096];
796    let n = match tokio::io::AsyncReadExt::read(&mut backend_read, &mut response_buf).await {
797        Ok(n) if n > 0 => n,
798        _ => {
799            tracing::error!("No response from backend for WebSocket upgrade");
800            metrics.inc_errors();
801            let body = http_body_util::Full::new(Bytes::from(
802                "Backend did not respond to WebSocket upgrade",
803            ))
804            .boxed();
805            return Ok(Response::builder().status(502).body(body).unwrap());
806        }
807    };
808
809    let response_str = String::from_utf8_lossy(&response_buf[..n]);
810    if !response_str.contains("101") {
811        tracing::error!(
812            "Backend rejected WebSocket upgrade: {}",
813            response_str.lines().next().unwrap_or("")
814        );
815        metrics.inc_errors();
816        let body =
817            http_body_util::Full::new(Bytes::from("Backend rejected WebSocket upgrade")).boxed();
818        return Ok(Response::builder().status(502).body(body).unwrap());
819    }
820
821    // Extract headers from backend 101 response
822    let mut accept_key = String::new();
823    let mut resp_protocol = None;
824    for line in response_str.lines().skip(1) {
825        if line.trim().is_empty() {
826            break;
827        }
828        if let Some((name, value)) = line.split_once(':') {
829            let name_lower = name.trim().to_lowercase();
830            let value = value.trim().to_string();
831            if name_lower == "sec-websocket-accept" {
832                accept_key = value;
833            } else if name_lower == "sec-websocket-protocol" {
834                resp_protocol = Some(value);
835            }
836        }
837    }
838
839    // Use hyper::upgrade::on to get the client-side stream after we return 101
840    let client_upgrade = hyper::upgrade::on(req);
841
842    // Reunite the backend halves
843    let backend_stream = backend_read.reunite(backend_write).unwrap();
844
845    // Spawn the bidirectional copy task
846    tokio::spawn(async move {
847        match client_upgrade.await {
848            Ok(upgraded) => {
849                let mut client_stream = TokioIo::new(upgraded);
850                let (mut br, mut bw) = tokio::io::split(backend_stream);
851                let (mut cr, mut cw) = tokio::io::split(&mut client_stream);
852                let _ = tokio::join!(
853                    tokio::io::copy(&mut br, &mut cw),
854                    tokio::io::copy(&mut cr, &mut bw),
855                );
856            }
857            Err(e) => {
858                tracing::error!("WebSocket client upgrade failed: {}", e);
859            }
860        }
861    });
862
863    // Return 101 Switching Protocols to the client
864    let mut resp = Response::builder()
865        .status(101)
866        .header("Upgrade", "websocket")
867        .header("Connection", "Upgrade")
868        .header("Sec-WebSocket-Accept", accept_key);
869    if let Some(proto) = resp_protocol {
870        resp = resp.header("Sec-WebSocket-Protocol", proto);
871    }
872    Ok(resp
873        .body(http_body_util::Full::new(Bytes::new()).boxed())
874        .unwrap())
875}
876
877/// Returns (Response, target_url_for_logging, route_scripts)
878async fn handle_regular_request(
879    req: Request<Incoming>,
880    client: ClientType,
881    config: &crate::config::Config,
882    lua_engine: &OptionalLuaEngine,
883    circuit_breaker: &SharedCircuitBreaker,
884    _app_manager: Option<Arc<AppManager>>,
885) -> Result<(Response<BoxBody>, String, Vec<String>), hyper::Error> {
886    let route = find_matching_rule(&req, &config.rules);
887
888    match route {
889        #[allow(unused_mut, unused_variables)]
890        Some(matched_route) => {
891            let path = req.uri().path().to_string();
892            let from_domain_rule = matched_route.from_domain_rule;
893            let matched_prefix = matched_route.matched_prefix();
894            let route_scripts = matched_route.route_scripts.clone();
895
896            // Select an available target via circuit breaker
897            let target_selection = select_target(&matched_route, &path, circuit_breaker);
898            let (mut target_url, base_url) = match target_selection {
899                Some((url, base)) => (url, base),
900                None => {
901                    // All targets are circuit-broken
902                    let body =
903                        http_body_util::Full::new(Bytes::from("Service Unavailable")).boxed();
904                    return Ok((
905                        Response::builder()
906                            .status(503)
907                            .body(body)
908                            .expect("Failed to build response"),
909                        String::new(),
910                        route_scripts,
911                    ));
912                }
913            };
914            // --- Lua route-specific on_request hooks ---
915            #[cfg(feature = "scripting")]
916            if let Some(ref engine) = lua_engine {
917                for script_name in &route_scripts {
918                    let mut lua_req = build_lua_request(&req);
919                    match engine.call_route_on_request(script_name, &mut lua_req) {
920                        RequestHookResult::Deny { status, body } => {
921                            let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
922                            return Ok((
923                                Response::builder().status(status).body(resp_body).unwrap(),
924                                target_url,
925                                route_scripts.clone(),
926                            ));
927                        }
928                        RequestHookResult::Continue(_) => {}
929                    }
930                }
931            }
932
933            // --- Lua on_route hook (global) ---
934            #[cfg(feature = "scripting")]
935            if let Some(ref engine) = lua_engine {
936                if engine.has_on_route() {
937                    let lua_req = build_lua_request(&req);
938                    match engine.call_on_route(&lua_req, &target_url) {
939                        RouteHookResult::Override(new_url) => {
940                            target_url = new_url;
941                        }
942                        RouteHookResult::Default => {}
943                    }
944                }
945                // Route-specific on_route hooks
946                for script_name in &route_scripts {
947                    let lua_req = build_lua_request(&req);
948                    match engine.call_route_on_route(script_name, &lua_req, &target_url) {
949                        RouteHookResult::Override(new_url) => {
950                            target_url = new_url;
951                        }
952                        RouteHookResult::Default => {}
953                    }
954                }
955            }
956
957            // Only extract host_header when needed (domain rules only)
958            let host_header = if from_domain_rule {
959                req.uri()
960                    .host()
961                    .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
962                    .map(|s| s.to_string())
963            } else {
964                None
965            };
966
967            let (mut parts, body) = req.into_parts();
968
969            // Move headers directly instead of cloning one by one
970            let uri: hyper::Uri = target_url.parse().expect("valid URI");
971            parts.uri = uri;
972            parts.version = http::Version::HTTP_11;
973            parts.extensions = http::Extensions::new();
974
975            let mut request = Request::from_parts(parts, body);
976
977            request
978                .headers_mut()
979                .insert("X-Forwarded-For", X_FORWARDED_FOR_VALUE.clone());
980
981            if from_domain_rule {
982                if let Some(host) = host_header {
983                    request
984                        .headers_mut()
985                        .insert("X-Forwarded-Host", host.parse().unwrap());
986                }
987            }
988
989            match client.request(request).await {
990                Ok(response) => {
991                    // --- Circuit breaker: record success or failure ---
992                    let status_code = response.status().as_u16();
993                    if circuit_breaker.is_failure_status(status_code) {
994                        circuit_breaker.record_failure(&base_url);
995                    } else {
996                        circuit_breaker.record_success(&base_url);
997                    }
998
999                    // --- Lua on_response hooks (global + route) ---
1000                    #[cfg(feature = "scripting")]
1001                    if let Some(ref engine) = lua_engine {
1002                        let has_global = engine.has_on_response();
1003                        let has_route = !route_scripts.is_empty();
1004
1005                        if has_global || has_route {
1006                            use crate::scripting::ResponseMod;
1007
1008                            let lua_req = LuaRequest {
1009                                method: String::new(),
1010                                path: String::new(),
1011                                headers: std::collections::HashMap::new(),
1012                                host: String::new(),
1013                                content_length: 0,
1014                            };
1015                            let resp_headers = extract_response_headers(response.headers());
1016                            let resp_status = response.status().as_u16();
1017
1018                            // Collect all mods: global first, then route scripts
1019                            let mut all_mods: Vec<ResponseMod> = Vec::new();
1020                            if has_global {
1021                                all_mods.push(engine.call_on_response(
1022                                    &lua_req,
1023                                    resp_status,
1024                                    &resp_headers,
1025                                ));
1026                            }
1027                            for script_name in &route_scripts {
1028                                all_mods.push(engine.call_route_on_response(
1029                                    script_name,
1030                                    &lua_req,
1031                                    resp_status,
1032                                    &resp_headers,
1033                                ));
1034                            }
1035
1036                            // Merge all mods
1037                            let mut merged = ResponseMod::default();
1038                            for mods in all_mods {
1039                                merged.set_headers.extend(mods.set_headers);
1040                                merged.remove_headers.extend(mods.remove_headers);
1041                                if mods.replace_body.is_some() {
1042                                    merged.replace_body = mods.replace_body;
1043                                }
1044                                if mods.override_status.is_some() {
1045                                    merged.override_status = mods.override_status;
1046                                }
1047                            }
1048
1049                            // Apply modifications if any
1050                            if !merged.set_headers.is_empty()
1051                                || !merged.remove_headers.is_empty()
1052                                || merged.replace_body.is_some()
1053                                || merged.override_status.is_some()
1054                            {
1055                                let (mut parts, body) = response.into_parts();
1056
1057                                if let Some(status) = merged.override_status {
1058                                    parts.status =
1059                                        hyper::StatusCode::from_u16(status).unwrap_or(parts.status);
1060                                }
1061
1062                                for name in &merged.remove_headers {
1063                                    if let Ok(header_name) =
1064                                        name.parse::<hyper::header::HeaderName>()
1065                                    {
1066                                        parts.headers.remove(header_name);
1067                                    }
1068                                }
1069
1070                                for (name, value) in &merged.set_headers {
1071                                    if let (Ok(header_name), Ok(header_value)) = (
1072                                        name.parse::<hyper::header::HeaderName>(),
1073                                        value.parse::<HeaderValue>(),
1074                                    ) {
1075                                        parts.headers.insert(header_name, header_value);
1076                                    }
1077                                }
1078
1079                                if let Some(new_body) = merged.replace_body {
1080                                    let new_bytes = Bytes::from(new_body);
1081                                    parts.headers.remove("content-length");
1082                                    parts.headers.insert(
1083                                        "content-length",
1084                                        new_bytes.len().to_string().parse().unwrap(),
1085                                    );
1086                                    let boxed = http_body_util::Full::new(new_bytes).boxed();
1087                                    return Ok((
1088                                        Response::from_parts(parts, boxed),
1089                                        target_url,
1090                                        route_scripts.clone(),
1091                                    ));
1092                                }
1093
1094                                let boxed = body.map_err(|_| unreachable!()).boxed();
1095                                return Ok((
1096                                    Response::from_parts(parts, boxed),
1097                                    target_url,
1098                                    route_scripts.clone(),
1099                                ));
1100                            }
1101                        }
1102                    }
1103
1104                    let is_html = response
1105                        .headers()
1106                        .get("content-type")
1107                        .and_then(|v| v.to_str().ok())
1108                        .map(|ct| ct.starts_with("text/html"))
1109                        .unwrap_or(false);
1110
1111                    if is_html {
1112                        if let Some(prefix) = matched_prefix {
1113                            let (parts, body) = response.into_parts();
1114                            let body_bytes = body
1115                                .collect()
1116                                .await
1117                                .map(|collected| collected.to_bytes())
1118                                .unwrap_or_default();
1119
1120                            // Decompress gzip/deflate body before rewriting
1121                            let is_gzip = parts
1122                                .headers
1123                                .get("content-encoding")
1124                                .and_then(|v| v.to_str().ok())
1125                                .map(|v| v.contains("gzip"))
1126                                .unwrap_or(false);
1127                            let is_deflate = parts
1128                                .headers
1129                                .get("content-encoding")
1130                                .and_then(|v| v.to_str().ok())
1131                                .map(|v| v.contains("deflate"))
1132                                .unwrap_or(false);
1133
1134                            let raw_bytes = if is_gzip {
1135                                use std::io::Read;
1136                                let mut decoder = flate2::read::GzDecoder::new(&body_bytes[..]);
1137                                let mut decoded = Vec::new();
1138                                decoder.read_to_end(&mut decoded).unwrap_or_default();
1139                                Bytes::from(decoded)
1140                            } else if is_deflate {
1141                                use std::io::Read;
1142                                let mut decoder =
1143                                    flate2::read::DeflateDecoder::new(&body_bytes[..]);
1144                                let mut decoded = Vec::new();
1145                                decoder.read_to_end(&mut decoded).unwrap_or_default();
1146                                Bytes::from(decoded)
1147                            } else {
1148                                body_bytes
1149                            };
1150
1151                            let html = String::from_utf8_lossy(&raw_bytes);
1152                            let rewritten = html
1153                                .replace("href=\"/", &format!("href=\"{}/", prefix))
1154                                .replace("src=\"/", &format!("src=\"{}/", prefix))
1155                                .replace("action=\"/", &format!("action=\"{}/", prefix));
1156                            let rewritten_bytes = Bytes::from(rewritten);
1157                            let mut parts = parts;
1158                            parts.headers.remove("content-encoding");
1159                            parts.headers.remove("content-length");
1160                            parts.headers.insert(
1161                                "content-length",
1162                                rewritten_bytes.len().to_string().parse().unwrap(),
1163                            );
1164                            let boxed = http_body_util::Full::new(rewritten_bytes).boxed();
1165                            return Ok((
1166                                Response::from_parts(parts, boxed),
1167                                target_url,
1168                                route_scripts.clone(),
1169                            ));
1170                        }
1171                    }
1172
1173                    let (parts, body) = response.into_parts();
1174                    let boxed = body.map_err(|_| unreachable!()).boxed();
1175                    Ok((
1176                        Response::from_parts(parts, boxed),
1177                        target_url,
1178                        route_scripts,
1179                    ))
1180                }
1181                Err(e) => {
1182                    circuit_breaker.record_failure(&base_url);
1183                    tracing::error!("Backend request failed: {} (target: {})", e, target_url);
1184                    let body = http_body_util::Full::new(Bytes::from("Bad Gateway")).boxed();
1185                    Ok((
1186                        Response::builder()
1187                            .status(502)
1188                            .body(body)
1189                            .expect("Failed to build response"),
1190                        target_url,
1191                        route_scripts,
1192                    ))
1193                }
1194            }
1195        }
1196        None => {
1197            // Suppress unused variable warning when scripting feature is disabled
1198            let _ = lua_engine;
1199            let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
1200            Ok((
1201                Response::builder()
1202                    .status(421)
1203                    .body(body)
1204                    .expect("Failed to build response"),
1205                String::new(),
1206                vec![],
1207            ))
1208        }
1209    }
1210}
1211
/// How the target URL is resolved from the matched route.
#[derive(Debug, Clone, PartialEq, Eq)]
enum UrlResolution {
    /// `Domain` and `Default` rules: append the full request path to the target URL.
    AppendPath,
    /// `DomainPath` and `Prefix` rules: remove the stored prefix from the request path
    /// and append the remainder to the target URL.
    StripPrefix(String),
    /// `Exact` and `Regex` rules: use the target URL as-is, ignoring the request path.
    Identity,
}
1221
1222/// A matched routing rule with all the info needed to resolve a target URL
1223struct MatchedRoute<'a> {
1224    targets: &'a [crate::config::Target],
1225    from_domain_rule: bool,
1226    resolution: UrlResolution,
1227    route_scripts: Vec<String>,
1228}
1229
1230impl<'a> MatchedRoute<'a> {
1231    fn matched_prefix(&self) -> Option<String> {
1232        match &self.resolution {
1233            UrlResolution::StripPrefix(prefix) => Some(prefix.trim_end_matches('/').to_string()),
1234            _ => None,
1235        }
1236    }
1237}
1238
1239/// Resolve a target URL based on the resolution strategy
1240fn resolve_target_url(
1241    target: &crate::config::Target,
1242    path: &str,
1243    resolution: &UrlResolution,
1244) -> String {
1245    let target_str = target.url.as_str();
1246    match resolution {
1247        UrlResolution::AppendPath => {
1248            if target_str.ends_with('/') {
1249                format!("{}{}", target_str, &path[1..])
1250            } else {
1251                format!("{}{}", target_str, path)
1252            }
1253        }
1254        UrlResolution::StripPrefix(prefix) => {
1255            let suffix = if path.len() >= prefix.len() {
1256                &path[prefix.len()..]
1257            } else {
1258                ""
1259            };
1260            format!("{}{}", target_str, suffix)
1261        }
1262        UrlResolution::Identity => target_str.to_owned(),
1263    }
1264}
1265
1266/// Pure routing: find which rule matches the request
1267fn find_matching_rule<'a>(
1268    req: &Request<Incoming>,
1269    rules: &'a [crate::config::ProxyRule],
1270) -> Option<MatchedRoute<'a>> {
1271    let host = req
1272        .uri()
1273        .host()
1274        .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1275        .map(|h| h.split(':').next().unwrap_or(h))?;
1276
1277    let path = req.uri().path();
1278    let mut matched_domain = false;
1279
1280    for rule in rules {
1281        match &rule.matcher {
1282            crate::config::RuleMatcher::Domain(domain) => {
1283                if domain == host {
1284                    matched_domain = true;
1285                    if !rule.targets.is_empty() {
1286                        return Some(MatchedRoute {
1287                            targets: &rule.targets,
1288                            from_domain_rule: true,
1289                            resolution: UrlResolution::AppendPath,
1290                            route_scripts: rule.scripts.clone(),
1291                        });
1292                    }
1293                }
1294            }
1295            crate::config::RuleMatcher::DomainPath(domain, path_prefix) => {
1296                if domain == host && !rule.targets.is_empty() {
1297                    let matches = path.starts_with(path_prefix)
1298                        || (path_prefix.ends_with('/')
1299                            && path == path_prefix.trim_end_matches('/'));
1300                    if matches {
1301                        return Some(MatchedRoute {
1302                            targets: &rule.targets,
1303                            from_domain_rule: true,
1304                            resolution: UrlResolution::StripPrefix(path_prefix.clone()),
1305                            route_scripts: rule.scripts.clone(),
1306                        });
1307                    }
1308                }
1309            }
1310            _ => {}
1311        }
1312    }
1313
1314    if matched_domain {
1315        return None;
1316    }
1317
1318    // Check specific rules (Exact, Prefix, Regex) before Default
1319    for rule in rules {
1320        match &rule.matcher {
1321            crate::config::RuleMatcher::Exact(exact) => {
1322                if path == exact && !rule.targets.is_empty() {
1323                    return Some(MatchedRoute {
1324                        targets: &rule.targets,
1325                        from_domain_rule: false,
1326                        resolution: UrlResolution::Identity,
1327                        route_scripts: rule.scripts.clone(),
1328                    });
1329                }
1330            }
1331            crate::config::RuleMatcher::Prefix(prefix) => {
1332                if !rule.targets.is_empty() {
1333                    // Match /db against prefix /db/ (path without trailing slash)
1334                    let matches = path.starts_with(prefix)
1335                        || (prefix.ends_with('/') && path == prefix.trim_end_matches('/'));
1336                    if matches {
1337                        return Some(MatchedRoute {
1338                            targets: &rule.targets,
1339                            from_domain_rule: false,
1340                            resolution: UrlResolution::StripPrefix(prefix.clone()),
1341                            route_scripts: rule.scripts.clone(),
1342                        });
1343                    }
1344                }
1345            }
1346            crate::config::RuleMatcher::Regex(ref rm) => {
1347                if rm.is_match(path) && !rule.targets.is_empty() {
1348                    return Some(MatchedRoute {
1349                        targets: &rule.targets,
1350                        from_domain_rule: false,
1351                        resolution: UrlResolution::Identity,
1352                        route_scripts: rule.scripts.clone(),
1353                    });
1354                }
1355            }
1356            _ => {}
1357        }
1358    }
1359
1360    // Fall back to Default rule
1361    for rule in rules {
1362        if let crate::config::RuleMatcher::Default = &rule.matcher {
1363            if !rule.targets.is_empty() {
1364                return Some(MatchedRoute {
1365                    targets: &rule.targets,
1366                    from_domain_rule: false,
1367                    resolution: UrlResolution::AppendPath,
1368                    route_scripts: rule.scripts.clone(),
1369                });
1370            }
1371        }
1372    }
1373
1374    None
1375}
1376
1377/// Select the first available target (circuit breaker aware).
1378/// Returns (resolved_url, base_url) for logging and circuit breaker tracking.
1379fn select_target(
1380    route: &MatchedRoute<'_>,
1381    path: &str,
1382    circuit_breaker: &crate::circuit_breaker::CircuitBreaker,
1383) -> Option<(String, String)> {
1384    for target in route.targets {
1385        let base_url = target.url.as_str().to_owned();
1386        if circuit_breaker.is_available(&base_url) {
1387            let resolved = resolve_target_url(target, path, &route.resolution);
1388            return Some((resolved, base_url));
1389        }
1390    }
1391    None
1392}
1393
1394/// Backward-compatible wrapper: returns (target_url, from_domain_rule, matched_prefix, route_scripts)
1395fn find_target(
1396    req: &Request<Incoming>,
1397    rules: &[crate::config::ProxyRule],
1398) -> Option<(String, bool, Option<String>, Vec<String>)> {
1399    let route = find_matching_rule(req, rules)?;
1400    let path = req.uri().path();
1401    let target = route.targets.first()?;
1402    let resolved = resolve_target_url(target, path, &route.resolution);
1403    let matched_prefix = route.matched_prefix();
1404    Some((
1405        resolved,
1406        route.from_domain_rule,
1407        matched_prefix,
1408        route.route_scripts,
1409    ))
1410}