1use std::net::SocketAddr;
4use std::sync::Arc;
5
6use arbiter_metrics::ArbiterMetrics;
7use hyper::server::conn::http1;
8use hyper::service::service_fn;
9use hyper_util::rt::TokioIo;
10use tokio::net::TcpListener;
11use tokio::signal;
12
13use crate::config::ProxyConfig;
14use crate::middleware::MiddlewareChain;
15use crate::proxy::{ProxyState, build_audit, handle_request};
16
17pub async fn run(config: ProxyConfig) -> anyhow::Result<()> {
19 let upstream_uri: hyper::Uri = config
21 .upstream
22 .url
23 .parse()
24 .map_err(|e| anyhow::anyhow!("invalid upstream URL '{}': {e}", config.upstream.url))?;
25 match upstream_uri.scheme_str() {
26 Some("http") | Some("https") => {}
27 Some(scheme) => {
28 anyhow::bail!(
29 "upstream URL scheme '{}' is not supported; use http or https",
30 scheme
31 );
32 }
33 None => {
34 anyhow::bail!(
35 "upstream URL '{}' has no scheme; use http:// or https://",
36 config.upstream.url
37 );
38 }
39 }
40
41 let addr: SocketAddr = format!(
42 "{}:{}",
43 config.server.listen_addr, config.server.listen_port
44 )
45 .parse()?;
46
47 let middleware = MiddlewareChain::from_config(&config.middleware);
48
49 let (audit_sink, redaction_config) = build_audit(&config.audit);
50 let metrics = Arc::new(
51 ArbiterMetrics::new().map_err(|e| anyhow::anyhow!("failed to create metrics: {e}"))?,
52 );
53
54 let state = Arc::new(ProxyState::new(
55 config.upstream.url.clone(),
56 middleware,
57 audit_sink,
58 redaction_config,
59 metrics,
60 config.server.max_body_bytes,
61 std::time::Duration::from_secs(config.server.upstream_timeout_secs),
62 ));
63
64 let listener = TcpListener::bind(addr).await?;
65 tracing::info!(%addr, upstream = %config.upstream.url, "proxy listening");
66
67 let header_read_timeout =
68 std::time::Duration::from_secs(config.server.header_read_timeout_secs);
69
70 let connection_semaphore = Arc::new(tokio::sync::Semaphore::new(config.server.max_connections));
72 tracing::info!(
73 max_connections = config.server.max_connections,
74 "connection limit configured"
75 );
76
77 let shutdown = shutdown_signal();
78 tokio::pin!(shutdown);
79
80 loop {
81 tokio::select! {
82 result = listener.accept() => {
83 let (stream, remote_addr) = result?;
84 let state = Arc::clone(&state);
85 tracing::debug!(%remote_addr, "accepted connection");
86
87 let sem = Arc::clone(&connection_semaphore);
88 tokio::spawn(async move {
89 let _permit = match sem.acquire().await {
91 Ok(permit) => permit,
92 Err(_) => {
93 tracing::error!("connection semaphore closed");
94 return;
95 }
96 };
97 let io = TokioIo::new(stream);
98 let svc = service_fn(move |req| {
99 let state = Arc::clone(&state);
100 handle_request(state, req)
101 });
102 if let Err(e) = http1::Builder::new()
103 .header_read_timeout(header_read_timeout)
104 .serve_connection(io, svc)
105 .await
106 {
107 tracing::error!(error = %e, %remote_addr, "connection error");
108 }
109 });
110 }
111 _ = &mut shutdown => {
112 tracing::info!("shutdown signal received, stopping");
113 break;
114 }
115 }
116 }
117
118 Ok(())
119}
120
121async fn shutdown_signal() {
123 let ctrl_c = async {
124 signal::ctrl_c()
125 .await
126 .expect("failed to install ctrl-c handler");
127 };
128
129 #[cfg(unix)]
130 let terminate = async {
131 signal::unix::signal(signal::unix::SignalKind::terminate())
132 .expect("failed to install SIGTERM handler")
133 .recv()
134 .await;
135 };
136
137 #[cfg(not(unix))]
138 let terminate = std::future::pending::<()>();
139
140 tokio::select! {
141 _ = ctrl_c => {}
142 _ = terminate => {}
143 }
144}