1use std::{
2 convert::Infallible, error::Error, future::Future, net::{IpAddr, SocketAddr}, sync::Arc, time::Instant
3};
4
5use arc_metrics::{helpers::{ActiveGauge, DurationIncMs, RegisterableMetric}, IntCounter, IntGauge};
6use hyper::{body::{Body, Incoming}, service::service_fn};
7use hyper_util::{
8 rt::{TokioExecutor, TokioIo},
9 server::conn::auto::Builder,
10};
11use tokio::{net::TcpListener, sync::Semaphore};
12
13#[cfg(feature = "metrics-server")]
14pub mod prom_metrics_server;
15
16pub use bytes;
18pub use http_body_util::{BodyExt, Full};
19pub use hyper::{self, Request, Response, body, StatusCode, header, Error as HyperError};
20
/// Application logic plugged into [`HttpServer`].
///
/// A single instance, shared behind an `Arc`, serves every connection. The
/// server calls [`HttpServerHandler::handle_request`] once per HTTP request.
/// The `Sync + Send + 'static` bounds allow that shared instance to be cloned
/// into the independently spawned per-connection tasks.
pub trait HttpServerHandler: Sync + Send + 'static {
    /// Response body type produced by the handler.
    ///
    /// The requirements (`Data: Send + Sync`, and an `Error` convertible into
    /// `Box<dyn Error + Send + Sync>`) are written as associated-type bounds.
    /// That makes them implied wherever `H: HttpServerHandler` is used, so
    /// generic code can rely on them without restating them.
    type Body: Body<Data: Send + Sync, Error: Into<Box<dyn Error + Send + Sync>>> + Send;

    /// Handles one HTTP request and produces its response.
    ///
    /// * `source` — the client IP. This is the TCP peer address, or the value
    ///   of the configured true-IP header when that header is present and
    ///   parses as an IP address.
    /// * `request` — the incoming request, with its body still streaming.
    ///
    /// There is no error channel: failures must be expressed as a response.
    /// The returned future must be `Send` because it runs on the multi-threaded
    /// executor.
    fn handle_request(
        self: Arc<Self>,
        source: IpAddr,
        request: Request<Incoming>,
    ) -> impl Future<Output = Response<Self::Body>> + Send;
}
30
/// HTTP server that accepts TCP (optionally TLS) connections and dispatches
/// every request to an [`HttpServerHandler`].
///
/// Build it with [`HttpServer::new`] and run it with [`HttpServer::start`].
/// `start` consumes the server.
pub struct HttpServer<H: HttpServerHandler> {
    /// Request handler; cloned into every connection task.
    handler: Arc<H>,
    /// Behaviour configuration, consumed by `start`.
    settings: HttpServerSettings,
    /// Metrics updated while serving; shared through `get_metrics`.
    metrics: Arc<HttpServerMetrics>,
}
36
/// Counters and gauges maintained by [`HttpServer`].
///
/// Obtain the shared instance with `HttpServer::get_metrics`. The
/// [`RegisterableMetric`] impl exposes every field under the metric name
/// noted on each field below.
#[derive(Default)]
pub struct HttpServerMetrics {
    /// Connections currently waiting for a `max_parallel` permit
    /// (`connections{status="waiting"}`).
    pub tcp_waiting: IntGauge,
    /// Connections currently holding a permit and being served
    /// (`connections{status="active"}`).
    pub tcp_sessions: IntGauge,

    /// Connections that found no free permit and had to wait
    /// (`blocked_waiting_count`).
    pub tcp_blocked_waiting_count: IntCounter,
    /// Accumulated milliseconds spent waiting for a permit
    /// (`blocked_waiting_duration_ms`).
    pub tcp_blocked_waiting_duration_ms: IntCounter,

    /// Connections that obtained a permit and started being served (`tcp_count`).
    /// The increment happens before any TLS handshake, so failed handshakes
    /// are included.
    pub tcp_accepts: IntCounter,
    /// Accumulated milliseconds connections spent being served (`tcp_duration_ms`).
    pub tcp_duration_ms: IntCounter,

    /// HTTP requests received across all connections (`http_request_count`).
    pub http_requests: IntCounter,
    /// Gauge registered as `http_sessions`; intended to track the requests
    /// currently being handled.
    pub http_sessions: IntGauge,

    /// `accept()` failures other than "too many open files"
    /// (`accept_error{reason="other"}`).
    pub tcp_accept_errors: IntCounter,
    /// `accept()` failures with OS error 24 (EMFILE) on Unix
    /// (`accept_error{reason="too_many_files"}`).
    pub tcp_accept_errors_too_many_files: IntCounter,
    /// Requests whose configured true-IP header was missing or did not parse
    /// as a single IP address; the peer address was used instead
    /// (`errors{type="true_ip_parse"}`).
    pub true_ip_parse_errors: IntCounter,
    /// Connections whose HTTP serving ended in an error
    /// (`errors{type="http_serve"}`).
    pub http_serve_errors: IntCounter,
    /// Failed TLS handshakes (`errors{type="tls_accept"}`).
    #[cfg(feature = "tls")]
    pub tls_accept_errors: IntCounter,
}
58
59impl RegisterableMetric for HttpServerMetrics {
60 fn register(&'static self, register: &mut arc_metrics::RegisterAction) {
61 register.gauge("connections", &self.tcp_waiting)
62 .attr("status", "waiting");
63 register.gauge("connections", &self.tcp_sessions)
64 .attr("status", "active");
65
66 register.count("blocked_waiting_count", &self.tcp_blocked_waiting_count);
67 register.count("blocked_waiting_duration_ms", &self.tcp_blocked_waiting_duration_ms);
68
69 register.count("tcp_count", &self.tcp_accepts);
70 register.count("tcp_duration_ms", &self.tcp_duration_ms);
71
72 register.count("http_request_count", &self.http_requests);
73 register.gauge("http_sessions", &self.http_sessions);
74
75 register.count("accept_error", &self.tcp_accept_errors_too_many_files).attr("reason", "too_many_files");
76 register.count("accept_error", &self.tcp_accept_errors).attr("reason", "other");
77
78 register.count("errors", &self.true_ip_parse_errors).attr("type", "true_ip_parse");
79 register.count("errors", &self.http_serve_errors).attr("type", "http_serve");
80 #[cfg(feature = "tls")]
81 register.count("errors", &self.tls_accept_errors).attr("type", "tls_accept");
82 }
83}
84
/// Configuration for [`HttpServer`]; see its `Default` impl for default values.
pub struct HttpServerSettings {
    /// Maximum number of TCP connections served concurrently.
    ///
    /// Additional accepted connections wait for a permit. `None` disables
    /// the limit. Defaults to `Some(200)`.
    pub max_parallel: Option<usize>,
    /// Name of a request header carrying the real client IP, typically set
    /// when running behind a proxy.
    ///
    /// The whole header value is parsed as a single `IpAddr`, so a
    /// comma-separated list does not parse. If the header is missing or does
    /// not parse, the TCP peer address is used and `true_ip_parse_errors` is
    /// incremented. `None` (the default) always uses the peer address.
    pub true_ip_header: Option<String>,
    /// Passed to the HTTP/1 connection builder's `keep_alive` option.
    /// Defaults to `true`.
    pub keep_alive: bool,
    /// Serve connections with `serve_connection_with_upgrades` instead of
    /// `serve_connection`, allowing HTTP upgrades such as WebSocket.
    /// Defaults to `false`.
    pub with_upgrades: bool,
    /// When `Some`, every accepted connection performs a TLS handshake before
    /// HTTP is served. Defaults to `None`.
    #[cfg(feature = "tls")]
    pub tls: Option<HttpTls>,
}
93
/// Source of the server's TLS certificate and private key (`tls` feature only).
#[cfg(feature = "tls")]
pub enum HttpTls {
    /// Certificate and key supplied in memory, passed to
    /// `tls_friend::tls_setup::TlsSetup::build_server(key, cert)`.
    /// NOTE(review): the expected encoding (PEM vs DER) is decided by
    /// `tls_friend` — confirm there.
    WithBytes { cert: Vec<u8>, key: Vec<u8> },
    /// Path to a PEM file, loaded with `TlsSetup::load_server`.
    /// NOTE(review): presumably the file holds both certificate and key — confirm.
    WithPemPath { path: String },
}
99
100impl Default for HttpServerSettings {
101 fn default() -> Self {
102 Self {
103 max_parallel: Some(200),
104 true_ip_header: None,
105 keep_alive: true,
106 with_upgrades: false,
107 #[cfg(feature = "tls")]
108 tls: None,
109 }
110 }
111}
112
113impl<H: HttpServerHandler> HttpServer<H> {
114 pub fn new(handler: Arc<H>, settings: HttpServerSettings) -> Self {
115 HttpServer {
116 handler,
117 settings,
118 metrics: Arc::new(HttpServerMetrics::default()),
119 }
120 }
121
122 pub fn get_metrics(&self) -> &Arc<HttpServerMetrics> {
123 &self.metrics
124 }
125
126 pub async fn start(self, listen_addr: SocketAddr) -> std::io::Result<()> {
127 #[cfg(feature = "tls")]
128 let tls = Arc::new(if let Some(tls) = &self.settings.tls {
129 tls_friend::install_crypto();
130
131 let acceptor = match tls {
132 HttpTls::WithBytes { cert, key } => tls_friend::tls_setup::TlsSetup::build_server(key, cert),
133 HttpTls::WithPemPath { path } => tls_friend::tls_setup::TlsSetup::load_server(&path).await,
134 }?.into_acceptor()?;
135
136 Some(acceptor)
137 } else { None });
138
139 tracing::info!(%listen_addr, "starting http server");
140
141 let metrics = self.metrics;
142 let tcp_listener = TcpListener::bind(listen_addr).await?;
143 let sem = self
144 .settings
145 .max_parallel
146 .map(|v| Arc::new(Semaphore::new(v)));
147
148 let true_ip_header = Arc::new(self.settings.true_ip_header);
149
150 loop {
151 let (stream, addr) = match tcp_listener.accept().await {
152 Ok(x) => x,
153 Err(error) => {
154 let counter = 'counter: {
155 #[cfg(target_family = "unix")]
156 {
157 if let Some(24) = error.raw_os_error() {
158 break 'counter &metrics.tcp_accept_errors_too_many_files;
159 }
160 }
161 &metrics.tcp_accept_errors
162 };
163 counter.inc();
164
165 tracing::error!(?error, "tcp failed to accept");
166 continue;
167 }
168 };
169
170 let sem = sem.clone();
171 let metrics = metrics.clone();
172 let handler = self.handler.clone();
173 let true_ip_header = true_ip_header.clone();
174
175 #[cfg(feature = "tls")]
176 let tls = tls.clone();
177
178 tokio::spawn(async move {
179 let _parallel_guard = 'block: {
180 let Some(sem) = sem.clone() else { break 'block None };
181
182 if let Ok(guard) = Arc::clone(&sem).try_acquire_owned() {
184 break 'block Some(guard);
185 }
186
187 metrics.tcp_blocked_waiting_count.inc();
188 let _waiting_count = ActiveGauge::new(&metrics, |m| &m.tcp_waiting);
189 let _waiting_duration = DurationIncMs::new(&metrics, |m| &m.tcp_blocked_waiting_duration_ms);
190 let guard = sem.acquire_owned().await.expect("Semaphore closed?");
191
192 Some(guard)
193 };
194
195 metrics.tcp_accepts.inc();
196
197 let _session_metric = ActiveGauge::new(&metrics, |m| &m.tcp_sessions);
198 let _duration_metric = DurationIncMs::new(&metrics, |m| &m.tcp_duration_ms);
199
200 let mut builder = Builder::new(TokioExecutor::new());
201 builder.http1().keep_alive(self.settings.keep_alive);
202
203 let handle = |req: Request<Incoming>| {
204 let handler = handler.clone();
205
206 metrics.http_requests.inc();
207 let _http_session_metric = ActiveGauge::new(&metrics, |m| &m.http_sessions);
208
209 let source_ip = if let Some(true_ip_header) = &*true_ip_header {
210 let true_ip_opt = req
211 .headers()
212 .get(true_ip_header)
213 .and_then(|ip| ip.to_str().ok())
214 .and_then(|ip| ip.parse::<IpAddr>().ok());
215
216 match true_ip_opt {
217 Some(v) => v,
218 None => {
219 metrics.true_ip_parse_errors.inc();
220 addr.ip()
221 }
222 }
223 } else {
224 addr.ip()
225 };
226
227 async move {
228 let res = handler.handle_request(source_ip, req).await;
229 Ok::<_, Infallible>(res)
230 }
231 };
232
233 #[cfg(feature = "tls")]
234 let stream = match &*tls {
235 Some(tls) => tls_friend::tls_streams::ServerStream::TlsStream({
236 match tls.accept(stream).await {
237 Ok(v) => v,
238 Err(error) => {
239 tracing::error!(?error, "failed to accept new tls stream");
240 metrics.tls_accept_errors.inc();
241 return;
242 }
243 }
244 }),
245 None => tls_friend::tls_streams::ServerStream::TcpStream(stream),
246 };
247
248 let result = if self.settings.with_upgrades {
249 builder.serve_connection_with_upgrades(TokioIo::new(stream), service_fn(handle)).await
250 } else {
251 builder.serve_connection(TokioIo::new(stream), service_fn(handle)).await
252 };
253
254 if let Err(e) = result {
255 tracing::error!(?e, %addr, "failed to serve request");
256 metrics.http_serve_errors.inc();
257 }
258 });
259 }
260 }
261}
262