Skip to main content

soli_proxy/server/
mod.rs

1// When scripting feature is disabled, OptionalLuaEngine = () and cloning it triggers warnings
2#![allow(clippy::let_unit_value, clippy::clone_on_copy, clippy::unit_arg)]
3
4use crate::acme::ChallengeStore;
5use crate::app::AppManager;
6use crate::auth;
7use crate::circuit_breaker::SharedCircuitBreaker;
8use crate::config::ConfigManager;
9use crate::metrics::SharedMetrics;
10use crate::pool::{ConnectionPool, ProxyClient};
11use crate::shutdown::ShutdownCoordinator;
12use anyhow::Result;
13use bytes::Bytes;
14use http_body_util::BodyExt;
15use hyper::body::Incoming;
16use hyper::header::HeaderValue;
17use hyper::service::service_fn;
18use hyper::Request;
19use hyper::Response;
20use hyper_util::rt::TokioExecutor;
21use hyper_util::rt::TokioIo;
22use socket2::{Domain, Protocol, Socket, Type};
23use std::net::SocketAddr;
24use std::sync::atomic::{AtomicUsize, Ordering};
25use std::sync::Arc;
26use std::time::Duration;
27use tokio::io::AsyncWriteExt;
28use tokio::net::{TcpListener, TcpStream};
29use tokio_rustls::TlsAcceptor;
30
31#[cfg(feature = "scripting")]
32use crate::scripting::{LuaEngine, LuaRequest, RequestHookResult, RouteHookResult};
33
34type BoxBody = http_body_util::combinators::BoxBody<Bytes, std::convert::Infallible>;
35
36#[cfg(feature = "scripting")]
37type OptionalLuaEngine = Option<LuaEngine>;
38#[cfg(not(feature = "scripting"))]
39type OptionalLuaEngine = ();
40
/// Per-rule round-robin counters used to spread requests across a rule's targets.
///
/// The counter table is sized once, at construction. The rule list it is indexed by comes
/// from configuration that may change afterwards (NOTE(review): assumed reloadable via
/// `ConfigManager` — confirm). An out-of-range rule index is therefore folded back into
/// the table instead of panicking.
pub struct LoadBalancerState {
    counters: Vec<AtomicUsize>,
}

impl LoadBalancerState {
    /// Creates state with one zeroed counter for each of `num_rules` rules.
    pub fn new(num_rules: usize) -> Self {
        Self {
            counters: (0..num_rules).map(|_| AtomicUsize::new(0)).collect(),
        }
    }

    /// Returns the next target index in `0..num_targets` for rule `rule_idx`.
    ///
    /// Returns `0` when `num_targets` is zero or when no counters were allocated.
    /// A `rule_idx` past the end of the table shares the counter at
    /// `rule_idx % counters.len()`. Indices inside the table keep their own counter.
    pub fn select_index(&self, rule_idx: usize, num_targets: usize) -> usize {
        if num_targets == 0 || self.counters.is_empty() {
            return 0;
        }
        let counter = &self.counters[rule_idx % self.counters.len()];
        // `Relaxed` suffices: the counter only hands out roughly even indices and does not
        // order any other memory accesses. `fetch_add` wraps on overflow.
        counter.fetch_add(1, Ordering::Relaxed) % num_targets
    }
}
59
60/// Helper to record app-specific metrics
61async fn record_app_metrics(
62    metrics: &SharedMetrics,
63    app_manager: &Option<Arc<AppManager>>,
64    target_url: &str,
65    bytes_in: u64,
66    bytes_out: u64,
67    status: u16,
68    duration: Duration,
69) {
70    if let Some(ref manager) = app_manager {
71        if let Ok(url) = url::Url::parse(target_url) {
72            if let Some(port) = url.port() {
73                if let Some(app_name) = manager.get_app_name(port).await {
74                    metrics.record_app_request(&app_name, bytes_in, bytes_out, status, duration);
75                }
76            }
77        }
78    }
79}
80
81/// Pre-parsed header value for X-Forwarded-For to avoid parsing on every request
82static X_FORWARDED_FOR_VALUE: std::sync::LazyLock<HeaderValue> =
83    std::sync::LazyLock::new(|| HeaderValue::from_static("127.0.0.1"));
84
85/// Verify Basic Auth credentials against stored hashes
86/// Returns true if credentials are valid, false otherwise
87fn verify_basic_auth(req: &Request<Incoming>, auth_entries: &[crate::auth::BasicAuth]) -> bool {
88    if auth_entries.is_empty() {
89        return true;
90    }
91
92    let auth_header = req.headers().get("authorization");
93    if auth_header.is_none() {
94        return false;
95    }
96
97    let header_value = auth_header.unwrap().to_str().unwrap_or("");
98    if !header_value.starts_with("Basic ") {
99        return false;
100    }
101
102    let encoded = &header_value[6..];
103    let decoded = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, encoded)
104        .unwrap_or_default();
105    let creds = String::from_utf8_lossy(&decoded);
106
107    if let Some((username, password)) = creds.split_once(':') {
108        for entry in auth_entries {
109            if entry.username == username && auth::verify_password(password, &entry.hash) {
110                return true;
111            }
112        }
113    }
114
115    false
116}
117
118/// Create 401 Unauthorized response with WWW-Authenticate header
119fn create_auth_required_response() -> Response<BoxBody> {
120    let body = http_body_util::Full::new(Bytes::from("Authentication required")).boxed();
121    Response::builder()
122        .status(401)
123        .header("WWW-Authenticate", "Basic realm=\"Restricted\"")
124        .body(body)
125        .unwrap()
126}
127
128fn create_listener(addr: SocketAddr) -> Result<TcpListener> {
129    let domain = if addr.is_ipv4() {
130        Domain::IPV4
131    } else {
132        Domain::IPV6
133    };
134    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
135    socket.set_reuse_address(true)?;
136    socket.set_reuse_port(true)?;
137    socket.set_nonblocking(true)?;
138    socket.bind(&addr.into())?;
139    socket.listen(8192)?;
140    let std_listener: std::net::TcpListener = socket.into();
141    Ok(TcpListener::from_std(std_listener)?)
142}
143
144pub struct ProxyServer {
145    config: Arc<ConfigManager>,
146    shutdown: ShutdownCoordinator,
147    tls_acceptor: Option<TlsAcceptor>,
148    https_addr: Option<SocketAddr>,
149    metrics: SharedMetrics,
150    challenge_store: ChallengeStore,
151    lua_engine: OptionalLuaEngine,
152    circuit_breaker: SharedCircuitBreaker,
153    app_manager: Option<Arc<AppManager>>,
154    load_balancer: Arc<LoadBalancerState>,
155}
156
157impl ProxyServer {
158    pub fn new(
159        config: Arc<ConfigManager>,
160        shutdown: ShutdownCoordinator,
161        metrics: SharedMetrics,
162        challenge_store: ChallengeStore,
163        lua_engine: OptionalLuaEngine,
164        circuit_breaker: SharedCircuitBreaker,
165        app_manager: Option<Arc<AppManager>>,
166    ) -> Result<Self> {
167        let num_rules = config.get_config().rules.len();
168        Ok(Self {
169            config,
170            shutdown,
171            tls_acceptor: None,
172            https_addr: None,
173            metrics,
174            challenge_store,
175            lua_engine,
176            circuit_breaker,
177            app_manager,
178            load_balancer: Arc::new(LoadBalancerState::new(num_rules)),
179        })
180    }
181
182    #[allow(clippy::too_many_arguments)]
183    pub fn with_https(
184        config: Arc<ConfigManager>,
185        shutdown: ShutdownCoordinator,
186        tls_acceptor: TlsAcceptor,
187        https_addr: SocketAddr,
188        metrics: SharedMetrics,
189        challenge_store: ChallengeStore,
190        lua_engine: OptionalLuaEngine,
191        circuit_breaker: SharedCircuitBreaker,
192        app_manager: Option<Arc<AppManager>>,
193    ) -> Result<Self> {
194        let num_rules = config.get_config().rules.len();
195        Ok(Self {
196            config,
197            shutdown,
198            tls_acceptor: Some(tls_acceptor),
199            https_addr: Some(https_addr),
200            metrics,
201            challenge_store,
202            lua_engine,
203            circuit_breaker,
204            app_manager,
205            load_balancer: Arc::new(LoadBalancerState::new(num_rules)),
206        })
207    }
208
209    pub async fn run(&self) -> Result<()> {
210        let cfg = self.config.get_config();
211        let http_addr: SocketAddr = cfg.server.bind.parse()?;
212        let https_addr = self.https_addr;
213
214        let has_https = https_addr.is_some();
215        let num_listeners = std::thread::available_parallelism()
216            .map(|n| n.get())
217            .unwrap_or(4);
218
219        // Spawn N HTTP accept loops with SO_REUSEPORT
220        // Each listener gets its own client with its own connection pool to avoid contention
221        let app_manager = self.app_manager.clone();
222        for i in 0..num_listeners {
223            let config_clone = self.config.clone();
224            let shutdown_clone = self.shutdown.clone();
225            let metrics_clone = self.metrics.clone();
226            let challenge_store_clone = self.challenge_store.clone();
227            let lua_clone = self.lua_engine.clone();
228            let cb_clone = self.circuit_breaker.clone();
229            let am_clone = app_manager.clone();
230            let lb_clone = self.load_balancer.clone();
231
232            tokio::spawn(async move {
233                if let Err(e) = run_http_server(
234                    http_addr,
235                    config_clone,
236                    shutdown_clone,
237                    metrics_clone,
238                    challenge_store_clone,
239                    lua_clone,
240                    cb_clone,
241                    am_clone,
242                    lb_clone,
243                )
244                .await
245                {
246                    tracing::error!("HTTP/1.1 server error (listener {}): {}", i, e);
247                }
248            });
249        }
250
251        if let Some(https_addr) = https_addr {
252            for i in 0..num_listeners {
253                let config_clone = self.config.clone();
254                let shutdown_clone = self.shutdown.clone();
255                let acceptor = self.tls_acceptor.as_ref().unwrap().clone();
256                let metrics_clone = self.metrics.clone();
257                let challenge_store_clone = self.challenge_store.clone();
258                let lua_clone = self.lua_engine.clone();
259                let cb_clone = self.circuit_breaker.clone();
260                let am_clone = app_manager.clone();
261                let lb_clone = self.load_balancer.clone();
262
263                tokio::spawn(async move {
264                    if let Err(e) = run_https_server(
265                        https_addr,
266                        config_clone,
267                        shutdown_clone,
268                        acceptor,
269                        metrics_clone,
270                        challenge_store_clone,
271                        lua_clone,
272                        cb_clone,
273                        am_clone,
274                        lb_clone,
275                    )
276                    .await
277                    {
278                        tracing::error!("HTTPS/2 server error (listener {}): {}", i, e);
279                    }
280                });
281            }
282        }
283
284        tracing::info!(
285            "HTTP/1.1 server listening on {} ({} accept loops)",
286            http_addr,
287            num_listeners
288        );
289        if has_https {
290            tracing::info!(
291                "HTTPS/2 server listening on {} ({} accept loops)",
292                https_addr.unwrap(),
293                num_listeners
294            );
295        }
296
297        loop {
298            if self.shutdown.is_shutting_down() {
299                tracing::info!("Shutting down servers...");
300                break;
301            }
302            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
303        }
304
305        Ok(())
306    }
307}
308
309#[allow(clippy::too_many_arguments)]
310async fn run_http_server(
311    addr: SocketAddr,
312    config: Arc<ConfigManager>,
313    shutdown: ShutdownCoordinator,
314    metrics: SharedMetrics,
315    challenge_store: ChallengeStore,
316    lua_engine: OptionalLuaEngine,
317    circuit_breaker: SharedCircuitBreaker,
318    app_manager: Option<Arc<AppManager>>,
319    load_balancer: Arc<LoadBalancerState>,
320) -> Result<()> {
321    let listener = create_listener(addr)?;
322    let client = ConnectionPool::new().client();
323
324    loop {
325        if shutdown.is_shutting_down() {
326            break;
327        }
328
329        match listener.accept().await {
330            Ok((stream, _)) => {
331                let _ = stream.set_nodelay(true);
332                let client = client.clone();
333                let config = config.clone();
334                let metrics = metrics.clone();
335                let cs = challenge_store.clone();
336                let lua = lua_engine.clone();
337                let cb = circuit_breaker.clone();
338                let am = app_manager.clone();
339                let lb = load_balancer.clone();
340                tokio::spawn(async move {
341                    if let Err(e) = handle_http11_connection(
342                        stream, client, config, metrics, cs, lua, cb, am, lb,
343                    )
344                    .await
345                    {
346                        tracing::debug!("HTTP/1.1 connection error: {}", e);
347                    }
348                });
349            }
350            Err(e) => {
351                tracing::error!("HTTP/1.1 accept error: {}", e);
352            }
353        }
354    }
355
356    Ok(())
357}
358
359#[allow(clippy::too_many_arguments)]
360async fn run_https_server(
361    addr: SocketAddr,
362    config: Arc<ConfigManager>,
363    shutdown: ShutdownCoordinator,
364    acceptor: TlsAcceptor,
365    metrics: SharedMetrics,
366    challenge_store: ChallengeStore,
367    lua_engine: OptionalLuaEngine,
368    circuit_breaker: SharedCircuitBreaker,
369    app_manager: Option<Arc<AppManager>>,
370    load_balancer: Arc<LoadBalancerState>,
371) -> Result<()> {
372    let listener = create_listener(addr)?;
373    let client = ConnectionPool::new().client();
374
375    loop {
376        if shutdown.is_shutting_down() {
377            break;
378        }
379
380        match listener.accept().await {
381            Ok((stream, _)) => {
382                let _ = stream.set_nodelay(true);
383                let client = client.clone();
384                let config = config.clone();
385                let acceptor = acceptor.clone();
386                let metrics = metrics.clone();
387                let cs = challenge_store.clone();
388                let lua = lua_engine.clone();
389                let cb = circuit_breaker.clone();
390                let am = app_manager.clone();
391                let lb = load_balancer.clone();
392                tokio::spawn(async move {
393                    match acceptor.accept(stream).await {
394                        Ok(tls_stream) => {
395                            metrics.inc_tls_connections();
396                            if let Err(e) = handle_https2_connection(
397                                tls_stream, client, config, metrics, cs, lua, cb, am, lb,
398                            )
399                            .await
400                            {
401                                tracing::debug!("HTTPS/2 connection error: {}", e);
402                            }
403                        }
404                        Err(e) => {
405                            tracing::error!("TLS accept error: {}", e);
406                        }
407                    }
408                });
409            }
410            Err(e) => {
411                tracing::error!("HTTPS/2 accept error: {}", e);
412            }
413        }
414    }
415
416    Ok(())
417}
418
419#[allow(clippy::too_many_arguments)]
420async fn handle_http11_connection(
421    stream: tokio::net::TcpStream,
422    client: ProxyClient,
423    config: Arc<ConfigManager>,
424    metrics: SharedMetrics,
425    challenge_store: ChallengeStore,
426    lua_engine: OptionalLuaEngine,
427    circuit_breaker: SharedCircuitBreaker,
428    app_manager: Option<Arc<AppManager>>,
429    load_balancer: Arc<LoadBalancerState>,
430) -> Result<()> {
431    let io = TokioIo::new(stream);
432    let svc = service_fn(move |req| {
433        handle_request(
434            req,
435            client.clone(),
436            config.clone(),
437            metrics.clone(),
438            challenge_store.clone(),
439            lua_engine.clone(),
440            circuit_breaker.clone(),
441            app_manager.clone(),
442            load_balancer.clone(),
443        )
444    });
445
446    let conn = hyper::server::conn::http1::Builder::new()
447        .keep_alive(true)
448        .pipeline_flush(true)
449        .serve_connection(io, svc)
450        .with_upgrades();
451
452    if let Err(e) = conn.await {
453        tracing::debug!("HTTP/1.1 connection error: {}", e);
454    }
455
456    Ok(())
457}
458
459#[allow(clippy::too_many_arguments)]
460async fn handle_https2_connection(
461    stream: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
462    client: ProxyClient,
463    config: Arc<ConfigManager>,
464    metrics: SharedMetrics,
465    challenge_store: ChallengeStore,
466    lua_engine: OptionalLuaEngine,
467    circuit_breaker: SharedCircuitBreaker,
468    app_manager: Option<Arc<AppManager>>,
469    load_balancer: Arc<LoadBalancerState>,
470) -> Result<()> {
471    let is_h2 = stream.get_ref().1.alpn_protocol() == Some(b"h2");
472
473    let io = TokioIo::new(stream);
474
475    if is_h2 {
476        let exec = TokioExecutor::new();
477        let svc = service_fn(move |req| {
478            handle_request(
479                req,
480                client.clone(),
481                config.clone(),
482                metrics.clone(),
483                challenge_store.clone(),
484                lua_engine.clone(),
485                circuit_breaker.clone(),
486                app_manager.clone(),
487                load_balancer.clone(),
488            )
489        });
490        let conn = hyper::server::conn::http2::Builder::new(exec)
491            .initial_stream_window_size(1024 * 1024)
492            .initial_connection_window_size(2 * 1024 * 1024)
493            .max_concurrent_streams(250)
494            .serve_connection(io, svc);
495        if let Err(e) = conn.await {
496            tracing::debug!("HTTPS/2 connection error: {}", e);
497        }
498    } else {
499        let svc = service_fn(move |req| {
500            handle_request(
501                req,
502                client.clone(),
503                config.clone(),
504                metrics.clone(),
505                challenge_store.clone(),
506                lua_engine.clone(),
507                circuit_breaker.clone(),
508                app_manager.clone(),
509                load_balancer.clone(),
510            )
511        });
512        let conn = hyper::server::conn::http1::Builder::new()
513            .keep_alive(true)
514            .pipeline_flush(true)
515            .serve_connection(io, svc)
516            .with_upgrades();
517        if let Err(e) = conn.await {
518            tracing::debug!("HTTPS/1.1 connection error: {}", e);
519        }
520    }
521
522    Ok(())
523}
524
/// Copies a request's headers into an owned `name -> value` map for the Lua hooks.
///
/// Names are lowercased. Values that are not visible ASCII become `""`. When a header is
/// repeated, the last value wins. This is the same conversion used for response headers.
#[cfg(feature = "scripting")]
fn extract_headers(req: &Request<Incoming>) -> std::collections::HashMap<String, String> {
    extract_response_headers(req.headers())
}
538
/// Snapshots a hyper request into the `LuaRequest` handed to Lua hooks.
///
/// Fields:
/// - `host`: the request URI's host if present, otherwise the `Host` header (which may
///   include a port), otherwise `""`;
/// - `content_length`: parsed from `Content-Length`, falling back to `0` when the header
///   is absent or does not parse.
#[cfg(feature = "scripting")]
fn build_lua_request(req: &Request<Incoming>) -> LuaRequest {
    let headers = req.headers();
    let header_str = |name: &'static str| headers.get(name).and_then(|v| v.to_str().ok());

    let host = match req.uri().host() {
        Some(uri_host) => uri_host.to_string(),
        None => header_str("host").unwrap_or("").to_string(),
    };

    let content_length = header_str("content-length")
        .and_then(|raw| raw.parse().ok())
        .unwrap_or(0);

    LuaRequest {
        method: req.method().to_string(),
        path: req.uri().path().to_string(),
        headers: extract_headers(req),
        host,
        content_length,
    }
}
564
/// Copies a header map into an owned `name -> value` map for Lua hooks.
///
/// Names are lowercased. Values that are not visible ASCII are replaced with `""`.
/// When a header appears more than once, the last value wins.
#[cfg(feature = "scripting")]
fn extract_response_headers(
    headers: &hyper::HeaderMap,
) -> std::collections::HashMap<String, String> {
    let mut owned = std::collections::HashMap::with_capacity(headers.len());
    for (name, value) in headers {
        owned.insert(
            name.as_str().to_lowercase(),
            value.to_str().unwrap_or("").to_owned(),
        );
    }
    owned
}
580
581#[allow(clippy::too_many_arguments)]
582async fn handle_request(
583    req: Request<Incoming>,
584    client: ProxyClient,
585    config_manager: Arc<ConfigManager>,
586    metrics: SharedMetrics,
587    challenge_store: ChallengeStore,
588    lua_engine: OptionalLuaEngine,
589    circuit_breaker: SharedCircuitBreaker,
590    app_manager: Option<Arc<AppManager>>,
591    load_balancer: Arc<LoadBalancerState>,
592) -> Result<Response<BoxBody>, hyper::Error> {
593    let start_time = std::time::Instant::now();
594    metrics.inc_in_flight();
595    let config = config_manager.get_config();
596
597    // ACME challenge check — must come before all other routing
598    if let Some(response) = handle_acme_challenge(&req, &challenge_store) {
599        metrics.dec_in_flight();
600        return Ok(response);
601    }
602
603    if is_metrics_request(&req) {
604        let duration = start_time.elapsed();
605        metrics.dec_in_flight();
606        let metrics_output = metrics.format_metrics();
607        metrics.record_request(0, metrics_output.len() as u64, 200, duration);
608        let body = http_body_util::Full::new(Bytes::from(metrics_output)).boxed();
609        return Ok(Response::builder()
610            .status(200)
611            .header("Content-Type", "text/plain")
612            .body(body)
613            .unwrap());
614    }
615
616    // --- Lua on_request hook ---
617    #[cfg(feature = "scripting")]
618    if let Some(ref engine) = lua_engine {
619        if engine.has_on_request() {
620            let mut lua_req = build_lua_request(&req);
621            match engine.call_on_request(&mut lua_req) {
622                RequestHookResult::Deny { status, body } => {
623                    metrics.dec_in_flight();
624                    let duration = start_time.elapsed();
625                    metrics.record_request(0, body.len() as u64, status, duration);
626                    let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
627                    return Ok(Response::builder().status(status).body(resp_body).unwrap());
628                }
629                RequestHookResult::Continue(updated_req) => {
630                    // Apply any header modifications back to the hyper request
631                    // We can't easily mutate the incoming request headers here since
632                    // we'd need to own it, so we store the lua_req for later use.
633                    // Headers set via set_header in on_request will be applied after
634                    // the request is decomposed into parts.
635                    let _ = updated_req;
636                }
637            }
638        }
639    }
640
641    let is_websocket = is_websocket_request(&req);
642
643    if is_websocket {
644        return handle_websocket_request(
645            req,
646            client,
647            &config,
648            &metrics,
649            start_time,
650            app_manager.clone(),
651        )
652        .await;
653    }
654
655    let result = handle_regular_request(
656        req,
657        client,
658        &config,
659        &lua_engine,
660        &circuit_breaker,
661        app_manager.clone(),
662        load_balancer.clone(),
663    )
664    .await;
665    let duration = start_time.elapsed();
666
667    metrics.dec_in_flight();
668
669    match result {
670        #[allow(unused_variables)]
671        Ok((response, _target_url, route_scripts)) => {
672            let status = response.status().as_u16();
673
674            // --- Lua on_request_end hooks (global + route) ---
675            #[cfg(feature = "scripting")]
676            if let Some(ref engine) = lua_engine {
677                let lua_req = LuaRequest {
678                    method: String::new(),
679                    path: String::new(),
680                    headers: std::collections::HashMap::new(),
681                    host: String::new(),
682                    content_length: 0,
683                };
684                let duration_ms = duration.as_secs_f64() * 1000.0;
685
686                // Global on_request_end
687                if engine.has_on_request_end() {
688                    engine.call_on_request_end(&lua_req, status, duration_ms, &_target_url);
689                }
690
691                // Route-specific on_request_end
692                for script_name in &route_scripts {
693                    engine.call_route_on_request_end(
694                        script_name,
695                        &lua_req,
696                        status,
697                        duration_ms,
698                        &_target_url,
699                    );
700                }
701            }
702
703            metrics.record_request(0, 0, status, duration);
704            record_app_metrics(&metrics, &app_manager, &_target_url, 0, 0, status, duration).await;
705            let (parts, body) = response.into_parts();
706            let boxed = body.map_err(|_| unreachable!()).boxed();
707            Ok(Response::from_parts(parts, boxed))
708        }
709        Err(e) => {
710            metrics.inc_errors();
711            Err(e)
712        }
713    }
714}
715
716fn is_websocket_request(req: &Request<Incoming>) -> bool {
717    if let Some(upgrade) = req.headers().get("upgrade") {
718        if upgrade == "websocket" {
719            return true;
720        }
721    }
722    false
723}
724
725fn is_metrics_request(req: &Request<Incoming>) -> bool {
726    req.uri().path() == "/metrics"
727}
728
729fn handle_acme_challenge(
730    req: &Request<Incoming>,
731    challenge_store: &ChallengeStore,
732) -> Option<Response<BoxBody>> {
733    let path = req.uri().path();
734    let prefix = "/.well-known/acme-challenge/";
735
736    if !path.starts_with(prefix) {
737        return None;
738    }
739
740    let token = &path[prefix.len()..];
741
742    if let Ok(store) = challenge_store.read() {
743        if let Some(key_auth) = store.get(token) {
744            let body = http_body_util::Full::new(Bytes::from(key_auth.clone())).boxed();
745            return Some(
746                Response::builder()
747                    .status(200)
748                    .header("Content-Type", "text/plain")
749                    .body(body)
750                    .unwrap(),
751            );
752        }
753    }
754
755    let body = http_body_util::Full::new(Bytes::from("Challenge not found")).boxed();
756    Some(Response::builder().status(404).body(body).unwrap())
757}
758
759async fn handle_websocket_request(
760    req: Request<Incoming>,
761    _client: ProxyClient,
762    config: &crate::config::Config,
763    metrics: &SharedMetrics,
764    _start_time: std::time::Instant,
765    _app_manager: Option<Arc<AppManager>>,
766) -> Result<Response<BoxBody>, hyper::Error> {
767    let target_result = find_target(&req, &config.rules);
768
769    if target_result.is_none() {
770        metrics.inc_errors();
771        let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
772        return Ok(Response::builder().status(421).body(body).unwrap());
773    }
774
775    let (target_url, _, _, _) = target_result.unwrap();
776
777    // Extract host:port from target URL (e.g. "http://127.0.0.1:3000/path" -> "127.0.0.1:3000")
778    let backend_addr = match url::Url::parse(&target_url) {
779        Ok(u) => format!(
780            "{}:{}",
781            u.host_str().unwrap_or("127.0.0.1"),
782            u.port().unwrap_or(80)
783        ),
784        Err(_) => {
785            metrics.inc_errors();
786            let body = http_body_util::Full::new(Bytes::from("Bad backend URL")).boxed();
787            return Ok(Response::builder().status(502).body(body).unwrap());
788        }
789    };
790
791    let path = req.uri().path().to_string();
792    let query = req
793        .uri()
794        .query()
795        .map(|q| format!("?{}", q))
796        .unwrap_or_default();
797
798    let ws_key = req
799        .headers()
800        .get("sec-websocket-key")
801        .and_then(|v| v.to_str().ok())
802        .unwrap_or("")
803        .to_string();
804    let ws_version = req
805        .headers()
806        .get("sec-websocket-version")
807        .and_then(|v| v.to_str().ok())
808        .unwrap_or("13")
809        .to_string();
810    let ws_protocol = req
811        .headers()
812        .get("sec-websocket-protocol")
813        .and_then(|v| v.to_str().ok())
814        .map(|s| s.to_string());
815    let host_header = req
816        .headers()
817        .get("host")
818        .and_then(|v| v.to_str().ok())
819        .unwrap_or(&backend_addr)
820        .to_string();
821
822    tracing::info!(
823        "WebSocket upgrade request to {}{}{}",
824        backend_addr,
825        path,
826        query
827    );
828
829    // Connect to the backend
830    let backend = match TcpStream::connect(&backend_addr).await {
831        Ok(s) => s,
832        Err(e) => {
833            tracing::error!("Failed to connect to backend for WebSocket: {}", e);
834            metrics.inc_errors();
835            let body = http_body_util::Full::new(Bytes::from("Backend not reachable")).boxed();
836            return Ok(Response::builder().status(502).body(body).unwrap());
837        }
838    };
839
840    // Send the upgrade request to the backend
841    let mut handshake = format!(
842        "GET {}{} HTTP/1.1\r\n\
843         Host: {}\r\n\
844         Upgrade: websocket\r\n\
845         Connection: Upgrade\r\n\
846         Sec-WebSocket-Key: {}\r\n\
847         Sec-WebSocket-Version: {}\r\n",
848        path, query, host_header, ws_key, ws_version,
849    );
850    if let Some(proto) = &ws_protocol {
851        handshake.push_str(&format!("Sec-WebSocket-Protocol: {}\r\n", proto));
852    }
853    handshake.push_str("\r\n");
854
855    let (mut backend_read, mut backend_write) = backend.into_split();
856    if let Err(e) = backend_write.write_all(handshake.as_bytes()).await {
857        tracing::error!("Failed to send WebSocket handshake to backend: {}", e);
858        metrics.inc_errors();
859        let body =
860            http_body_util::Full::new(Bytes::from("Failed to initiate WebSocket with backend"))
861                .boxed();
862        return Ok(Response::builder().status(502).body(body).unwrap());
863    }
864
865    // Read the backend's 101 response
866    let mut response_buf = vec![0u8; 4096];
867    let n = match tokio::io::AsyncReadExt::read(&mut backend_read, &mut response_buf).await {
868        Ok(n) if n > 0 => n,
869        _ => {
870            tracing::error!("No response from backend for WebSocket upgrade");
871            metrics.inc_errors();
872            let body = http_body_util::Full::new(Bytes::from(
873                "Backend did not respond to WebSocket upgrade",
874            ))
875            .boxed();
876            return Ok(Response::builder().status(502).body(body).unwrap());
877        }
878    };
879
880    let response_str = String::from_utf8_lossy(&response_buf[..n]);
881    if !response_str.contains("101") {
882        tracing::error!(
883            "Backend rejected WebSocket upgrade: {}",
884            response_str.lines().next().unwrap_or("")
885        );
886        metrics.inc_errors();
887        let body =
888            http_body_util::Full::new(Bytes::from("Backend rejected WebSocket upgrade")).boxed();
889        return Ok(Response::builder().status(502).body(body).unwrap());
890    }
891
892    // Extract headers from backend 101 response
893    let mut accept_key = String::new();
894    let mut resp_protocol = None;
895    for line in response_str.lines().skip(1) {
896        if line.trim().is_empty() {
897            break;
898        }
899        if let Some((name, value)) = line.split_once(':') {
900            let name_lower = name.trim().to_lowercase();
901            let value = value.trim().to_string();
902            if name_lower == "sec-websocket-accept" {
903                accept_key = value;
904            } else if name_lower == "sec-websocket-protocol" {
905                resp_protocol = Some(value);
906            }
907        }
908    }
909
910    // Use hyper::upgrade::on to get the client-side stream after we return 101
911    let client_upgrade = hyper::upgrade::on(req);
912
913    // Reunite the backend halves
914    let backend_stream = backend_read.reunite(backend_write).unwrap();
915
916    // Spawn the bidirectional copy task
917    tokio::spawn(async move {
918        match client_upgrade.await {
919            Ok(upgraded) => {
920                let mut client_stream = TokioIo::new(upgraded);
921                let (mut br, mut bw) = tokio::io::split(backend_stream);
922                let (mut cr, mut cw) = tokio::io::split(&mut client_stream);
923                let _ = tokio::join!(
924                    tokio::io::copy(&mut br, &mut cw),
925                    tokio::io::copy(&mut cr, &mut bw),
926                );
927            }
928            Err(e) => {
929                tracing::error!("WebSocket client upgrade failed: {}", e);
930            }
931        }
932    });
933
934    // Return 101 Switching Protocols to the client
935    let mut resp = Response::builder()
936        .status(101)
937        .header("Upgrade", "websocket")
938        .header("Connection", "Upgrade")
939        .header("Sec-WebSocket-Accept", accept_key);
940    if let Some(proto) = resp_protocol {
941        resp = resp.header("Sec-WebSocket-Protocol", proto);
942    }
943    Ok(resp
944        .body(http_body_util::Full::new(Bytes::new()).boxed())
945        .unwrap())
946}
947
948/// Returns (Response, target_url_for_logging, route_scripts)
949async fn handle_regular_request(
950    req: Request<Incoming>,
951    client: ProxyClient,
952    config: &crate::config::Config,
953    lua_engine: &OptionalLuaEngine,
954    circuit_breaker: &SharedCircuitBreaker,
955    _app_manager: Option<Arc<AppManager>>,
956    load_balancer: Arc<LoadBalancerState>,
957) -> Result<(Response<BoxBody>, String, Vec<String>), hyper::Error> {
958    let route = find_matching_rule(&req, &config.rules);
959
960    match route {
961        #[allow(unused_mut, unused_variables)]
962        Some(matched_route) => {
963            let path = req.uri().path().to_string();
964            let from_domain_rule = matched_route.from_domain_rule;
965            let matched_prefix = matched_route.matched_prefix();
966
967            if !matched_route.auth.is_empty() && !verify_basic_auth(&req, &matched_route.auth) {
968                tracing::debug!("Basic auth failed for {}", req.uri().path());
969                return Ok((create_auth_required_response(), String::new(), vec![]));
970            }
971            let route_scripts = matched_route.route_scripts.clone();
972
973            // Select an available target via circuit breaker
974            let target_selection =
975                select_target(&matched_route, &path, circuit_breaker, &load_balancer);
976            let (mut target_url, base_url) = match target_selection {
977                Some((url, base)) => (url, base),
978                None => {
979                    // All targets are circuit-broken
980                    let body =
981                        http_body_util::Full::new(Bytes::from("Service Unavailable")).boxed();
982                    return Ok((
983                        Response::builder()
984                            .status(503)
985                            .body(body)
986                            .expect("Failed to build response"),
987                        String::new(),
988                        route_scripts,
989                    ));
990                }
991            };
992            // --- Lua route-specific on_request hooks ---
993            #[cfg(feature = "scripting")]
994            if let Some(ref engine) = lua_engine {
995                for script_name in &route_scripts {
996                    let mut lua_req = build_lua_request(&req);
997                    match engine.call_route_on_request(script_name, &mut lua_req) {
998                        RequestHookResult::Deny { status, body } => {
999                            let resp_body = http_body_util::Full::new(Bytes::from(body)).boxed();
1000                            return Ok((
1001                                Response::builder().status(status).body(resp_body).unwrap(),
1002                                target_url,
1003                                route_scripts.clone(),
1004                            ));
1005                        }
1006                        RequestHookResult::Continue(_) => {}
1007                    }
1008                }
1009            }
1010
1011            // --- Lua on_route hook (global) ---
1012            #[cfg(feature = "scripting")]
1013            if let Some(ref engine) = lua_engine {
1014                if engine.has_on_route() {
1015                    let lua_req = build_lua_request(&req);
1016                    match engine.call_on_route(&lua_req, &target_url) {
1017                        RouteHookResult::Override(new_url) => {
1018                            target_url = new_url;
1019                        }
1020                        RouteHookResult::Default => {}
1021                    }
1022                }
1023                // Route-specific on_route hooks
1024                for script_name in &route_scripts {
1025                    let lua_req = build_lua_request(&req);
1026                    match engine.call_route_on_route(script_name, &lua_req, &target_url) {
1027                        RouteHookResult::Override(new_url) => {
1028                            target_url = new_url;
1029                        }
1030                        RouteHookResult::Default => {}
1031                    }
1032                }
1033            }
1034
1035            // Only extract host_header when needed (domain rules only)
1036            let host_header = if from_domain_rule {
1037                req.uri()
1038                    .host()
1039                    .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1040                    .map(|s| s.to_string())
1041            } else {
1042                None
1043            };
1044
1045            let (mut parts, body) = req.into_parts();
1046
1047            // Move headers directly instead of cloning one by one
1048            let uri: hyper::Uri = target_url.parse().expect("valid URI");
1049            parts.uri = uri;
1050            parts.version = http::Version::HTTP_11;
1051            parts.extensions = http::Extensions::new();
1052
1053            let mut request = Request::from_parts(parts, body);
1054
1055            request
1056                .headers_mut()
1057                .insert("X-Forwarded-For", X_FORWARDED_FOR_VALUE.clone());
1058
1059            if from_domain_rule {
1060                if let Some(host) = host_header {
1061                    request
1062                        .headers_mut()
1063                        .insert("X-Forwarded-Host", host.parse().unwrap());
1064                }
1065            }
1066
1067            match client.request(request).await {
1068                Ok(response) => {
1069                    // --- Circuit breaker: record success or failure ---
1070                    let status_code = response.status().as_u16();
1071                    if circuit_breaker.is_failure_status(status_code) {
1072                        circuit_breaker.record_failure(&base_url);
1073                    } else {
1074                        circuit_breaker.record_success(&base_url);
1075                    }
1076
1077                    // --- Lua on_response hooks (global + route) ---
1078                    #[cfg(feature = "scripting")]
1079                    if let Some(ref engine) = lua_engine {
1080                        let has_global = engine.has_on_response();
1081                        let has_route = !route_scripts.is_empty();
1082
1083                        if has_global || has_route {
1084                            use crate::scripting::ResponseMod;
1085
1086                            let lua_req = LuaRequest {
1087                                method: String::new(),
1088                                path: String::new(),
1089                                headers: std::collections::HashMap::new(),
1090                                host: String::new(),
1091                                content_length: 0,
1092                            };
1093                            let resp_headers = extract_response_headers(response.headers());
1094                            let resp_status = response.status().as_u16();
1095
1096                            // Collect all mods: global first, then route scripts
1097                            let mut all_mods: Vec<ResponseMod> = Vec::new();
1098                            if has_global {
1099                                all_mods.push(engine.call_on_response(
1100                                    &lua_req,
1101                                    resp_status,
1102                                    &resp_headers,
1103                                ));
1104                            }
1105                            for script_name in &route_scripts {
1106                                all_mods.push(engine.call_route_on_response(
1107                                    script_name,
1108                                    &lua_req,
1109                                    resp_status,
1110                                    &resp_headers,
1111                                ));
1112                            }
1113
1114                            // Merge all mods
1115                            let mut merged = ResponseMod::default();
1116                            for mods in all_mods {
1117                                merged.set_headers.extend(mods.set_headers);
1118                                merged.remove_headers.extend(mods.remove_headers);
1119                                if mods.replace_body.is_some() {
1120                                    merged.replace_body = mods.replace_body;
1121                                }
1122                                if mods.override_status.is_some() {
1123                                    merged.override_status = mods.override_status;
1124                                }
1125                            }
1126
1127                            // Apply modifications if any
1128                            if !merged.set_headers.is_empty()
1129                                || !merged.remove_headers.is_empty()
1130                                || merged.replace_body.is_some()
1131                                || merged.override_status.is_some()
1132                            {
1133                                let (mut parts, body) = response.into_parts();
1134
1135                                if let Some(status) = merged.override_status {
1136                                    parts.status =
1137                                        hyper::StatusCode::from_u16(status).unwrap_or(parts.status);
1138                                }
1139
1140                                for name in &merged.remove_headers {
1141                                    if let Ok(header_name) =
1142                                        name.parse::<hyper::header::HeaderName>()
1143                                    {
1144                                        parts.headers.remove(header_name);
1145                                    }
1146                                }
1147
1148                                for (name, value) in &merged.set_headers {
1149                                    if let (Ok(header_name), Ok(header_value)) = (
1150                                        name.parse::<hyper::header::HeaderName>(),
1151                                        value.parse::<HeaderValue>(),
1152                                    ) {
1153                                        parts.headers.insert(header_name, header_value);
1154                                    }
1155                                }
1156
1157                                if let Some(new_body) = merged.replace_body {
1158                                    let new_bytes = Bytes::from(new_body);
1159                                    parts.headers.remove("content-length");
1160                                    parts.headers.insert(
1161                                        "content-length",
1162                                        new_bytes.len().to_string().parse().unwrap(),
1163                                    );
1164                                    let boxed = http_body_util::Full::new(new_bytes).boxed();
1165                                    return Ok((
1166                                        Response::from_parts(parts, boxed),
1167                                        target_url,
1168                                        route_scripts.clone(),
1169                                    ));
1170                                }
1171
1172                                let boxed = body.map_err(|_| unreachable!()).boxed();
1173                                return Ok((
1174                                    Response::from_parts(parts, boxed),
1175                                    target_url,
1176                                    route_scripts.clone(),
1177                                ));
1178                            }
1179                        }
1180                    }
1181
1182                    let is_html = response
1183                        .headers()
1184                        .get("content-type")
1185                        .and_then(|v| v.to_str().ok())
1186                        .map(|ct| ct.starts_with("text/html"))
1187                        .unwrap_or(false);
1188
1189                    if is_html {
1190                        if let Some(prefix) = matched_prefix {
1191                            let (parts, body) = response.into_parts();
1192                            let body_bytes = body
1193                                .collect()
1194                                .await
1195                                .map(|collected| collected.to_bytes())
1196                                .unwrap_or_default();
1197
1198                            // Decompress gzip/deflate body before rewriting
1199                            let is_gzip = parts
1200                                .headers
1201                                .get("content-encoding")
1202                                .and_then(|v| v.to_str().ok())
1203                                .map(|v| v.contains("gzip"))
1204                                .unwrap_or(false);
1205                            let is_deflate = parts
1206                                .headers
1207                                .get("content-encoding")
1208                                .and_then(|v| v.to_str().ok())
1209                                .map(|v| v.contains("deflate"))
1210                                .unwrap_or(false);
1211
1212                            let raw_bytes = if is_gzip {
1213                                use std::io::Read;
1214                                let mut decoder = flate2::read::GzDecoder::new(&body_bytes[..]);
1215                                let mut decoded = Vec::new();
1216                                decoder.read_to_end(&mut decoded).unwrap_or_default();
1217                                Bytes::from(decoded)
1218                            } else if is_deflate {
1219                                use std::io::Read;
1220                                let mut decoder =
1221                                    flate2::read::DeflateDecoder::new(&body_bytes[..]);
1222                                let mut decoded = Vec::new();
1223                                decoder.read_to_end(&mut decoded).unwrap_or_default();
1224                                Bytes::from(decoded)
1225                            } else {
1226                                body_bytes
1227                            };
1228
1229                            let html = String::from_utf8_lossy(&raw_bytes);
1230                            let rewritten = html
1231                                .replace("href=\"/", &format!("href=\"{}/", prefix))
1232                                .replace("src=\"/", &format!("src=\"{}/", prefix))
1233                                .replace("action=\"/", &format!("action=\"{}/", prefix));
1234                            let rewritten_bytes = Bytes::from(rewritten);
1235                            let mut parts = parts;
1236                            parts.headers.remove("content-encoding");
1237                            parts.headers.remove("content-length");
1238                            parts.headers.insert(
1239                                "content-length",
1240                                rewritten_bytes.len().to_string().parse().unwrap(),
1241                            );
1242                            let boxed = http_body_util::Full::new(rewritten_bytes).boxed();
1243                            return Ok((
1244                                Response::from_parts(parts, boxed),
1245                                target_url,
1246                                route_scripts.clone(),
1247                            ));
1248                        }
1249                    }
1250
1251                    let (parts, body) = response.into_parts();
1252                    let boxed = body.map_err(|_| unreachable!()).boxed();
1253                    Ok((
1254                        Response::from_parts(parts, boxed),
1255                        target_url,
1256                        route_scripts,
1257                    ))
1258                }
1259                Err(e) => {
1260                    circuit_breaker.record_failure(&base_url);
1261                    tracing::error!("Backend request failed: {} (target: {})", e, target_url);
1262                    let body = http_body_util::Full::new(Bytes::from("Bad Gateway")).boxed();
1263                    Ok((
1264                        Response::builder()
1265                            .status(502)
1266                            .body(body)
1267                            .expect("Failed to build response"),
1268                        target_url,
1269                        route_scripts,
1270                    ))
1271                }
1272            }
1273        }
1274        None => {
1275            // Suppress unused variable warning when scripting feature is disabled
1276            let _ = lua_engine;
1277            let body = http_body_util::Full::new(Bytes::from("Misdirected Request")).boxed();
1278            Ok((
1279                Response::builder()
1280                    .status(421)
1281                    .body(body)
1282                    .expect("Failed to build response"),
1283                String::new(),
1284                vec![],
1285            ))
1286        }
1287    }
1288}
1289
/// How the target URL is resolved from the matched route.
///
/// `find_matching_rule` picks the variant per rule kind: `Domain` →
/// `AppendPath`; `DomainPath` and `Prefix` → `StripPrefix`; `Exact`, `Regex`
/// and `Default` → `Identity`.
enum UrlResolution {
    /// Append the full request path to the target URL.
    AppendPath,
    /// Strip the contained prefix from the request path and append the remainder.
    StripPrefix(String),
    /// Use the target URL as-is, ignoring the request path.
    /// NOTE(review): this means `Default` (catch-all) rules drop the request
    /// path; confirm that is intended rather than `AppendPath`.
    Identity,
}
1299
1300/// A matched routing rule with all the info needed to resolve a target URL
1301struct MatchedRoute<'a> {
1302    targets: &'a [crate::config::Target],
1303    from_domain_rule: bool,
1304    resolution: UrlResolution,
1305    route_scripts: Vec<String>,
1306    auth: Vec<crate::auth::BasicAuth>,
1307    load_balancing: &'a crate::config::LoadBalancingStrategy,
1308}
1309
1310impl<'a> MatchedRoute<'a> {
1311    fn matched_prefix(&self) -> Option<String> {
1312        match &self.resolution {
1313            UrlResolution::StripPrefix(prefix) => Some(prefix.trim_end_matches('/').to_string()),
1314            _ => None,
1315        }
1316    }
1317}
1318
1319/// Resolve a target URL based on the resolution strategy
1320fn resolve_target_url(
1321    target: &crate::config::Target,
1322    path: &str,
1323    resolution: &UrlResolution,
1324) -> String {
1325    let target_str = target.url.as_str();
1326    match resolution {
1327        UrlResolution::AppendPath => {
1328            if target_str.ends_with('/') {
1329                format!("{}{}", target_str, &path[1..])
1330            } else {
1331                format!("{}{}", target_str, path)
1332            }
1333        }
1334        UrlResolution::StripPrefix(prefix) => {
1335            let suffix = if path.len() >= prefix.len() {
1336                &path[prefix.len()..]
1337            } else {
1338                ""
1339            };
1340            format!("{}{}", target_str, suffix)
1341        }
1342        UrlResolution::Identity => target_str.to_owned(),
1343    }
1344}
1345
1346/// Pure routing: find which rule matches the request
1347fn find_matching_rule<'a>(
1348    req: &Request<Incoming>,
1349    rules: &'a [crate::config::ProxyRule],
1350) -> Option<MatchedRoute<'a>> {
1351    let host = req
1352        .uri()
1353        .host()
1354        .or(req.headers().get("host").and_then(|h| h.to_str().ok()))
1355        .map(|h| h.split(':').next().unwrap_or(h))?;
1356
1357    let path = req.uri().path();
1358
1359    for rule in rules {
1360        match &rule.matcher {
1361            crate::config::RuleMatcher::Domain(domain) => {
1362                if domain == host && !rule.targets.is_empty() {
1363                    return Some(MatchedRoute {
1364                        targets: &rule.targets,
1365                        from_domain_rule: true,
1366                        resolution: UrlResolution::AppendPath,
1367                        route_scripts: rule.scripts.clone(),
1368                        auth: rule.auth.clone(),
1369                        load_balancing: &rule.load_balancing,
1370                    });
1371                }
1372            }
1373            crate::config::RuleMatcher::DomainPath(domain, path_prefix) => {
1374                if domain == host && !rule.targets.is_empty() {
1375                    let matches = path.starts_with(path_prefix)
1376                        || (path_prefix.ends_with('/')
1377                            && path == path_prefix.trim_end_matches('/'));
1378                    if matches {
1379                        return Some(MatchedRoute {
1380                            targets: &rule.targets,
1381                            from_domain_rule: true,
1382                            resolution: UrlResolution::StripPrefix(path_prefix.clone()),
1383                            route_scripts: rule.scripts.clone(),
1384                            auth: rule.auth.clone(),
1385                            load_balancing: &rule.load_balancing,
1386                        });
1387                    }
1388                }
1389            }
1390            _ => {}
1391        }
1392    }
1393
1394    // Check specific rules (Exact, Prefix, Regex) before Default
1395    for rule in rules {
1396        match &rule.matcher {
1397            crate::config::RuleMatcher::Exact(exact) => {
1398                if path == exact && !rule.targets.is_empty() {
1399                    return Some(MatchedRoute {
1400                        targets: &rule.targets,
1401                        from_domain_rule: false,
1402                        resolution: UrlResolution::Identity,
1403                        route_scripts: rule.scripts.clone(),
1404                        auth: rule.auth.clone(),
1405                        load_balancing: &rule.load_balancing,
1406                    });
1407                }
1408            }
1409            crate::config::RuleMatcher::Prefix(prefix) => {
1410                if !rule.targets.is_empty() {
1411                    // Match /db against prefix /db/ (path without trailing slash)
1412                    let matches = path.starts_with(prefix)
1413                        || (prefix.ends_with('/') && path == prefix.trim_end_matches('/'));
1414                    if matches {
1415                        return Some(MatchedRoute {
1416                            targets: &rule.targets,
1417                            from_domain_rule: false,
1418                            resolution: UrlResolution::StripPrefix(prefix.clone()),
1419                            route_scripts: rule.scripts.clone(),
1420                            auth: rule.auth.clone(),
1421                            load_balancing: &rule.load_balancing,
1422                        });
1423                    }
1424                }
1425            }
1426            crate::config::RuleMatcher::Regex(ref rm) => {
1427                if rm.is_match(path) && !rule.targets.is_empty() {
1428                    return Some(MatchedRoute {
1429                        targets: &rule.targets,
1430                        from_domain_rule: false,
1431                        resolution: UrlResolution::Identity,
1432                        route_scripts: rule.scripts.clone(),
1433                        auth: rule.auth.clone(),
1434                        load_balancing: &rule.load_balancing,
1435                    });
1436                }
1437            }
1438            _ => {}
1439        }
1440    }
1441
1442    // Fall back to Default rule
1443    for rule in rules {
1444        if let crate::config::RuleMatcher::Default = &rule.matcher {
1445            if !rule.targets.is_empty() {
1446                return Some(MatchedRoute {
1447                    targets: &rule.targets,
1448                    from_domain_rule: false,
1449                    resolution: UrlResolution::Identity,
1450                    route_scripts: rule.scripts.clone(),
1451                    auth: rule.auth.clone(),
1452                    load_balancing: &rule.load_balancing,
1453                });
1454            }
1455        }
1456    }
1457
1458    None
1459}
1460
1461/// Select a target based on the load balancing strategy.
1462/// Returns (resolved_url, base_url) for logging and circuit breaker tracking.
1463fn select_target(
1464    route: &MatchedRoute<'_>,
1465    path: &str,
1466    circuit_breaker: &crate::circuit_breaker::CircuitBreaker,
1467    load_balancer: &LoadBalancerState,
1468) -> Option<(String, String)> {
1469    let targets = route.targets;
1470    if targets.is_empty() {
1471        return None;
1472    }
1473
1474    match route.load_balancing {
1475        crate::config::LoadBalancingStrategy::Failover => {
1476            // Failover: use first available target (circuit breaker aware)
1477            for target in targets {
1478                let base_url = target.url.as_str().to_owned();
1479                if circuit_breaker.is_available(&base_url) {
1480                    let resolved = resolve_target_url(target, path, &route.resolution);
1481                    return Some((resolved, base_url));
1482                }
1483            }
1484            None
1485        }
1486        crate::config::LoadBalancingStrategy::RoundRobin => {
1487            // Round-robin: cycle through all targets, skip unhealthy ones
1488            let num_targets = targets.len();
1489            let start_idx = load_balancer.counters[0].load(Ordering::Relaxed) % num_targets;
1490
1491            for i in 0..num_targets {
1492                let idx = (start_idx + i) % num_targets;
1493                let target = &targets[idx];
1494                let base_url = target.url.as_str().to_owned();
1495                if circuit_breaker.is_available(&base_url) {
1496                    load_balancer.counters[0].fetch_add(1, Ordering::Relaxed);
1497                    let resolved = resolve_target_url(target, path, &route.resolution);
1498                    return Some((resolved, base_url));
1499                }
1500            }
1501            None
1502        }
1503        crate::config::LoadBalancingStrategy::Weighted => {
1504            // Weighted: use weights to determine distribution, skip unhealthy
1505            let total_weight: u32 = targets.iter().map(|t| t.weight as u32).sum();
1506            if total_weight == 0 {
1507                return select_target(route, path, circuit_breaker, &LoadBalancerState::new(1));
1508            }
1509
1510            let start_idx =
1511                (load_balancer.counters[0].load(Ordering::Relaxed) % total_weight as usize) as u32;
1512            let mut cumulative = 0u32;
1513
1514            for target in targets.iter() {
1515                cumulative += target.weight as u32;
1516                let base_url = target.url.as_str().to_owned();
1517                if cumulative > start_idx && circuit_breaker.is_available(&base_url) {
1518                    load_balancer.counters[0].fetch_add(1, Ordering::Relaxed);
1519                    let resolved = resolve_target_url(target, path, &route.resolution);
1520                    return Some((resolved, base_url));
1521                }
1522            }
1523
1524            // Fallback: try any available target
1525            for target in targets {
1526                let base_url = target.url.as_str().to_owned();
1527                if circuit_breaker.is_available(&base_url) {
1528                    let resolved = resolve_target_url(target, path, &route.resolution);
1529                    return Some((resolved, base_url));
1530                }
1531            }
1532            None
1533        }
1534    }
1535}
1536
1537/// Backward-compatible wrapper: returns (target_url, from_domain_rule, matched_prefix, route_scripts)
1538fn find_target(
1539    req: &Request<Incoming>,
1540    rules: &[crate::config::ProxyRule],
1541) -> Option<(String, bool, Option<String>, Vec<String>)> {
1542    let route = find_matching_rule(req, rules)?;
1543    let path = req.uri().path();
1544    let target = route.targets.first()?;
1545    let resolved = resolve_target_url(target, path, &route.resolution);
1546    let matched_prefix = route.matched_prefix();
1547    Some((
1548        resolved,
1549        route.from_domain_rule,
1550        matched_prefix,
1551        route.route_scripts,
1552    ))
1553}
1554
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_load_balancer_state_select_index() {
        let lb = LoadBalancerState::new(1);
        // Consecutive calls walk 0, 1, 2 and then wrap back around to 0.
        for expected in [0, 1, 2, 0] {
            assert_eq!(lb.select_index(0, 3), expected);
        }
    }

    #[test]
    fn test_load_balancer_state_zero_targets() {
        // With nothing to rotate over, index 0 is returned instead of dividing by zero.
        let state = LoadBalancerState::new(1);
        let picked = state.select_index(0, 0);
        assert_eq!(picked, 0);
    }
}