http_app/
lib.rs

1use std::{
2    convert::Infallible, error::Error, future::Future, net::{IpAddr, SocketAddr}, sync::Arc, time::Instant
3};
4
5use arc_metrics::{helpers::{ActiveGauge, DurationIncMs, RegisterableMetric}, IntCounter, IntGauge};
6use hyper::{body::{Body, Incoming}, service::service_fn};
7use hyper_util::{
8    rt::{TokioExecutor, TokioIo},
9    server::conn::auto::Builder,
10};
11use tokio::{net::TcpListener, sync::Semaphore};
12
13#[cfg(feature = "metrics-server")]
14pub mod prom_metrics_server;
15
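/* Re-exported so downstream crates can build requests and responses with exactly the same crate versions (and hyper types) this library's API uses, without declaring their own dependencies. */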
17pub use bytes;
18pub use http_body_util::{BodyExt, Full};
19pub use hyper::{self, Request, Response, body, StatusCode, header, Error as HyperError};
20
21pub trait HttpServerHandler: Sync + Send + 'static {
22    type Body: Body<Data: Send + Sync, Error: Into<Box<dyn Error + Send + Sync>>> + Send;
23
24    fn handle_request(
25        self: Arc<Self>,
26        source: IpAddr,
27        request: Request<Incoming>,
28    ) -> impl Future<Output = Response<Self::Body>> + Send;
29}
30
31pub struct HttpServer<H: HttpServerHandler> {
32    handler: Arc<H>,
33    settings: HttpServerSettings,
34    metrics: Arc<HttpServerMetrics>,
35}
36
37#[derive(Default)]
38pub struct HttpServerMetrics {
39    pub tcp_waiting: IntGauge,
40    pub tcp_sessions: IntGauge,
41
42    pub tcp_blocked_waiting_count: IntCounter,
43    pub tcp_blocked_waiting_duration_ms: IntCounter,
44
45    pub tcp_accepts: IntCounter,
46    pub tcp_duration_ms: IntCounter,
47
48    pub http_requests: IntCounter,
49    pub http_sessions: IntGauge,
50
51    pub tcp_accept_errors: IntCounter,
52    pub tcp_accept_errors_too_many_files: IntCounter,
53    pub true_ip_parse_errors: IntCounter,
54    pub http_serve_errors: IntCounter,
55    #[cfg(feature = "tls")]
56    pub tls_accept_errors: IntCounter,
57}
58
59impl RegisterableMetric for HttpServerMetrics {
60    fn register(&'static self, register: &mut arc_metrics::RegisterAction) {
61        register.gauge("connections", &self.tcp_waiting)
62            .attr("status", "waiting");
63        register.gauge("connections", &self.tcp_sessions)
64            .attr("status", "active");
65
66        register.count("blocked_waiting_count", &self.tcp_blocked_waiting_count);
67        register.count("blocked_waiting_duration_ms", &self.tcp_blocked_waiting_duration_ms);
68
69        register.count("tcp_count", &self.tcp_accepts);
70        register.count("tcp_duration_ms", &self.tcp_duration_ms);
71
72        register.count("http_request_count", &self.http_requests);
73        register.gauge("http_sessions", &self.http_sessions);
74
75        register.count("accept_error", &self.tcp_accept_errors_too_many_files).attr("reason", "too_many_files");
76        register.count("accept_error", &self.tcp_accept_errors).attr("reason", "other");
77
78        register.count("errors", &self.true_ip_parse_errors).attr("type", "true_ip_parse");
79        register.count("errors", &self.http_serve_errors).attr("type", "http_serve");
80        #[cfg(feature = "tls")]
81        register.count("errors", &self.tls_accept_errors).attr("type", "tls_accept");
82    }
83}
84
/// Configuration for [`HttpServer`].
pub struct HttpServerSettings {
    /// Maximum number of connections served at once. Connections beyond the limit are still
    /// accepted, but they wait for a free slot before being served. `None` removes the limit.
    pub max_parallel: Option<usize>,
    /// Name of a request header whose value is passed to the handler as the client IP instead
    /// of the TCP peer address. Missing or unparseable values fall back to the peer address.
    /// NOTE(review): the value is trusted as-is, so only set this behind a proxy that
    /// overwrites the header.
    pub true_ip_header: Option<String>,
    /// Enables HTTP/1 keep-alive.
    pub keep_alive: bool,
    /// Serves connections with HTTP upgrade support (`serve_connection_with_upgrades`).
    pub with_upgrades: bool,
    /// TLS configuration; `None` serves plain TCP.
    #[cfg(feature = "tls")]
    pub tls: Option<HttpTls>,
}
93
/// Where the server's TLS certificate and private key come from.
#[cfg(feature = "tls")]
pub enum HttpTls {
    /// Certificate and key supplied in memory, passed to `TlsSetup::build_server`.
    /// NOTE(review): the expected encoding (PEM vs DER) is decided by `tls_friend` — confirm.
    WithBytes { cert: Vec<u8>, key: Vec<u8> },
    /// File path passed to `TlsSetup::load_server`, which is awaited during startup.
    /// Presumably a PEM file holding both the certificate and the key — confirm.
    WithPemPath { path: String },
}
99
100impl Default for HttpServerSettings {
101    fn default() -> Self {
102        Self {
103            max_parallel: Some(200),
104            true_ip_header: None,
105            keep_alive: true,
106            with_upgrades: false,
107            #[cfg(feature = "tls")]
108            tls: None,
109        }
110    }
111}
112
113impl<H: HttpServerHandler> HttpServer<H> {
114    pub fn new(handler: Arc<H>, settings: HttpServerSettings) -> Self {
115        HttpServer {
116            handler,
117            settings,
118            metrics: Arc::new(HttpServerMetrics::default()),
119        }
120    }
121
122    pub fn get_metrics(&self) -> &Arc<HttpServerMetrics> {
123        &self.metrics
124    }
125
126    pub async fn start(self, listen_addr: SocketAddr) -> std::io::Result<()> {
127        #[cfg(feature = "tls")]
128        let tls = Arc::new(if let Some(tls) = &self.settings.tls {
129            tls_friend::install_crypto();
130
131            let acceptor = match tls {
132                HttpTls::WithBytes { cert, key } => tls_friend::tls_setup::TlsSetup::build_server(key, cert),
133                HttpTls::WithPemPath { path } => tls_friend::tls_setup::TlsSetup::load_server(&path).await,
134            }?.into_acceptor()?;
135
136            Some(acceptor)
137        } else { None });
138
139        tracing::info!(%listen_addr, "starting http server");
140
141        let metrics = self.metrics;
142        let tcp_listener = TcpListener::bind(listen_addr).await?;
143        let sem = self
144            .settings
145            .max_parallel
146            .map(|v| Arc::new(Semaphore::new(v)));
147
148        let true_ip_header = Arc::new(self.settings.true_ip_header);
149
150        loop {
151            let (stream, addr) = match tcp_listener.accept().await {
152                Ok(x) => x,
153                Err(error) => {
154                    let counter = 'counter: {
155                        #[cfg(target_family = "unix")]
156                        {
157                            if let Some(24) = error.raw_os_error() {
158                                break 'counter &metrics.tcp_accept_errors_too_many_files;
159                            }
160                        }
161                        &metrics.tcp_accept_errors
162                    };
163                    counter.inc();
164
165                    tracing::error!(?error, "tcp failed to accept");
166                    continue;
167                }
168            };
169
170            let sem = sem.clone();
171            let metrics = metrics.clone();
172            let handler = self.handler.clone();
173            let true_ip_header = true_ip_header.clone();
174
175            #[cfg(feature = "tls")]
176            let tls = tls.clone();
177
178            tokio::spawn(async move {
179                let _parallel_guard = 'block: {
180                    let Some(sem) = sem.clone() else { break 'block None };
181
182                    /* try non blocking first so we only update metrics if we're block */
183                    if let Ok(guard) = Arc::clone(&sem).try_acquire_owned() {
184                        break 'block Some(guard);
185                    }
186
187                    metrics.tcp_blocked_waiting_count.inc();
188                    let _waiting_count = ActiveGauge::new(&metrics, |m| &m.tcp_waiting);
189                    let _waiting_duration = DurationIncMs::new(&metrics, |m| &m.tcp_blocked_waiting_duration_ms);
190                    let guard = sem.acquire_owned().await.expect("Semaphore closed?");
191
192                    Some(guard)
193                };
194
195                metrics.tcp_accepts.inc();
196
197                let _session_metric = ActiveGauge::new(&metrics, |m| &m.tcp_sessions);
198                let _duration_metric = DurationIncMs::new(&metrics, |m| &m.tcp_duration_ms);
199
200                let mut builder = Builder::new(TokioExecutor::new());
201                builder.http1().keep_alive(self.settings.keep_alive);
202
203                let handle = |req: Request<Incoming>| {
204                    let handler = handler.clone();
205
206                    metrics.http_requests.inc();
207                    let _http_session_metric = ActiveGauge::new(&metrics, |m| &m.http_sessions);
208
209                    let source_ip = if let Some(true_ip_header) = &*true_ip_header {
210                        let true_ip_opt = req
211                            .headers()
212                            .get(true_ip_header)
213                            .and_then(|ip| ip.to_str().ok())
214                            .and_then(|ip| ip.parse::<IpAddr>().ok());
215
216                        match true_ip_opt {
217                            Some(v) => v,
218                            None => {
219                                metrics.true_ip_parse_errors.inc();
220                                addr.ip()
221                            }
222                        }
223                    } else {
224                        addr.ip()
225                    };
226
227                    async move {
228                        let res = handler.handle_request(source_ip, req).await;
229                        Ok::<_, Infallible>(res)
230                    }
231                };
232
233                #[cfg(feature = "tls")]
234                let stream = match &*tls {
235                    Some(tls) => tls_friend::tls_streams::ServerStream::TlsStream({
236                        match tls.accept(stream).await {
237                            Ok(v) => v,
238                            Err(error) => {
239                                tracing::error!(?error, "failed to accept new tls stream");
240                                metrics.tls_accept_errors.inc();
241                                return;
242                            }
243                        }
244                    }),
245                    None => tls_friend::tls_streams::ServerStream::TcpStream(stream),
246                };
247
248                let result = if self.settings.with_upgrades {
249                    builder.serve_connection_with_upgrades(TokioIo::new(stream), service_fn(handle)).await
250                } else {
251                    builder.serve_connection(TokioIo::new(stream), service_fn(handle)).await
252                };
253
254                if let Err(e) = result {
255                    tracing::error!(?e, %addr, "failed to serve request");
256                    metrics.http_serve_errors.inc();
257                }
258            });
259        }
260    }
261}
262