Skip to main content

arbiter_proxy/
server.rs

1//! Server bootstrap and graceful shutdown.
2
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use arbiter_metrics::ArbiterMetrics;
7use hyper::server::conn::http1;
8use hyper::service::service_fn;
9use hyper_util::rt::TokioIo;
10use tokio::net::TcpListener;
11use tokio::signal;
12
13use crate::config::ProxyConfig;
14use crate::middleware::MiddlewareChain;
15use crate::proxy::{ProxyState, build_audit, handle_request};
16
17/// Run the proxy server. Blocks until a shutdown signal is received.
18pub async fn run(config: ProxyConfig) -> anyhow::Result<()> {
19    // Validate upstream URL at startup, not at first request.
20    let upstream_uri: hyper::Uri = config
21        .upstream
22        .url
23        .parse()
24        .map_err(|e| anyhow::anyhow!("invalid upstream URL '{}': {e}", config.upstream.url))?;
25    match upstream_uri.scheme_str() {
26        Some("http") | Some("https") => {}
27        Some(scheme) => {
28            anyhow::bail!(
29                "upstream URL scheme '{}' is not supported; use http or https",
30                scheme
31            );
32        }
33        None => {
34            anyhow::bail!(
35                "upstream URL '{}' has no scheme; use http:// or https://",
36                config.upstream.url
37            );
38        }
39    }
40
41    let addr: SocketAddr = format!(
42        "{}:{}",
43        config.server.listen_addr, config.server.listen_port
44    )
45    .parse()?;
46
47    let middleware = MiddlewareChain::from_config(&config.middleware);
48
49    let (audit_sink, redaction_config) = build_audit(&config.audit);
50    let metrics = Arc::new(
51        ArbiterMetrics::new().map_err(|e| anyhow::anyhow!("failed to create metrics: {e}"))?,
52    );
53
54    let state = Arc::new(ProxyState::new(
55        config.upstream.url.clone(),
56        middleware,
57        audit_sink,
58        redaction_config,
59        metrics,
60        config.server.max_body_bytes,
61        std::time::Duration::from_secs(config.server.upstream_timeout_secs),
62    ));
63
64    let listener = TcpListener::bind(addr).await?;
65    tracing::info!(%addr, upstream = %config.upstream.url, "proxy listening");
66
67    let header_read_timeout =
68        std::time::Duration::from_secs(config.server.header_read_timeout_secs);
69
70    // Connection concurrency limit to prevent resource exhaustion from connection floods.
71    let connection_semaphore = Arc::new(tokio::sync::Semaphore::new(config.server.max_connections));
72    tracing::info!(
73        max_connections = config.server.max_connections,
74        "connection limit configured"
75    );
76
77    let shutdown = shutdown_signal();
78    tokio::pin!(shutdown);
79
80    loop {
81        tokio::select! {
82            result = listener.accept() => {
83                let (stream, remote_addr) = result?;
84                let state = Arc::clone(&state);
85                tracing::debug!(%remote_addr, "accepted connection");
86
87                let sem = Arc::clone(&connection_semaphore);
88                tokio::spawn(async move {
89                    // Acquire a permit before serving; drop releases it.
90                    let _permit = match sem.acquire().await {
91                        Ok(permit) => permit,
92                        Err(_) => {
93                            tracing::error!("connection semaphore closed");
94                            return;
95                        }
96                    };
97                    let io = TokioIo::new(stream);
98                    let svc = service_fn(move |req| {
99                        let state = Arc::clone(&state);
100                        handle_request(state, req)
101                    });
102                    if let Err(e) = http1::Builder::new()
103                        .header_read_timeout(header_read_timeout)
104                        .serve_connection(io, svc)
105                        .await
106                    {
107                        tracing::error!(error = %e, %remote_addr, "connection error");
108                    }
109                });
110            }
111            _ = &mut shutdown => {
112                tracing::info!("shutdown signal received, stopping");
113                break;
114            }
115        }
116    }
117
118    Ok(())
119}
120
121/// Wait for SIGTERM or SIGINT (ctrl-c).
122async fn shutdown_signal() {
123    let ctrl_c = async {
124        signal::ctrl_c()
125            .await
126            .expect("failed to install ctrl-c handler");
127    };
128
129    #[cfg(unix)]
130    let terminate = async {
131        signal::unix::signal(signal::unix::SignalKind::terminate())
132            .expect("failed to install SIGTERM handler")
133            .recv()
134            .await;
135    };
136
137    #[cfg(not(unix))]
138    let terminate = std::future::pending::<()>();
139
140    tokio::select! {
141        _ = ctrl_c => {}
142        _ = terminate => {}
143    }
144}