Skip to main content

ic_bn_lib/http/server/
mod.rs

1pub mod proxy_protocol;
2
3use std::{
4    fmt::Display,
5    net::SocketAddr,
6    path::PathBuf,
7    sync::{
8        Arc,
9        atomic::{AtomicU32, AtomicU64, Ordering},
10    },
11    time::{Duration, Instant},
12};
13
14use anyhow::{Context, anyhow};
15use async_trait::async_trait;
16use axum::{Router, extract::Request};
17use http::Response;
18use hyper::body::Incoming;
19use hyper_util::{
20    rt::{TokioExecutor, TokioIo, TokioTimer},
21    server::conn::auto::Builder,
22};
23use ic_bn_lib_common::{
24    traits::Run,
25    types::{
26        http::{
27            ALPN_ACME, Addr, ConnInfo, Error, ListenerOpts, Metrics, ProxyProtocolMode,
28            ServerOptions, TlsInfo,
29        },
30        tls::TlsOptions,
31    },
32};
33use prometheus::{
34    Registry,
35    core::{AtomicI64, GenericGauge},
36};
37use proxy_protocol::{ProxyHeader, ProxyProtocolStream};
38use rustls::sign::SingleCertAndKey;
39use scopeguard::defer;
40use tokio::{
41    io::AsyncWriteExt,
42    pin, select,
43    sync::mpsc::channel,
44    time::{sleep, timeout},
45};
46use tokio_io_timeout::TimeoutStream;
47use tokio_util::{sync::CancellationToken, task::TaskTracker};
48use tower_service::Service;
49use tracing::{debug, info, warn};
50use uuid::Uuid;
51
52use super::body::NotifyingBody;
53use crate::{
54    network::{AsyncCounter, AsyncReadWrite, listener::Listener, tls_handshake},
55    tls::{pem_convert_to_rustls, prepare_server_config},
56};
57
/// Length of a 365-day year (86400 seconds per day).
/// Multiplied by 10 elsewhere in this file to get an idle-timer deadline
/// that effectively never fires.
const YEAR: Duration = Duration::from_secs(86400 * 365);
59
/// Notification sent over the per-connection channel to tell the connection
/// loop that a request changed its state.
#[derive(Clone)]
enum RequestState {
    /// Sent by the service closure when it starts handling a request
    Start,
    /// Handed to `NotifyingBody` (together with the channel sender) when the
    /// response body is wrapped; the connection loop treats it as "request finished"
    End,
}
65
/// State of a single accepted connection, created by `Server::spawn_connection`.
struct Conn {
    /// Address the server is listening on
    addr: Addr,
    /// Address of the remote peer as reported by the listener
    remote_addr: Addr,
    /// Router that serves the requests of this connection
    router: Router,
    /// Hyper connection builder used to serve the stream
    builder: Builder<TokioExecutor>,
    /// Cancelled to start a graceful shutdown of this connection
    token_graceful: CancellationToken,
    /// Cancelled to close this connection immediately; a clone is exposed through `ConnInfo::close`
    token_forceful: CancellationToken,
    /// Server options (timeouts, limits, Proxy Protocol mode, ...)
    options: ServerOptions,
    /// Metrics to record connection & request statistics into
    metrics: Metrics,
    /// Number of requests currently outstanding on this connection (drives the idle timer)
    requests: AtomicU32,
    /// TLS configuration; `None` means the connection is served without TLS
    rustls_cfg: Option<Arc<rustls::ServerConfig>>,
}
78
79impl Display for Conn {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        write!(f, "[{}] <- [{}]", self.addr, self.remote_addr)
82    }
83}
84
85impl Conn {
86    async fn handle(&self, stream: Box<dyn AsyncReadWrite>) -> Result<(), Error> {
87        let accepted_at = Instant::now();
88
89        debug!("{self}: got a new connection");
90
91        // Prepare metric labels
92        let addr = self.addr.to_string();
93        let labels = &mut [
94            addr.as_str(),             // Listening addr
95            self.remote_addr.family(), // Remote client address family
96            "no",                      // TLS version
97            "no",                      // TLS ciphersuite
98            "no",                      // Force-closed
99            "no",                      // Recycled
100        ];
101
102        // Wrap with traffic counter
103        let (stream, stats) = AsyncCounter::new(stream);
104
105        // Read & parse Proxy Protocol v2 header if configured
106        let (stream, proxy_hdr): (Box<dyn AsyncReadWrite>, Option<ProxyHeader>) =
107            if self.options.proxy_protocol_mode != ProxyProtocolMode::Off {
108                let (stream, hdr) = ProxyProtocolStream::accept(stream)
109                    .await
110                    .context("unable to accept Proxy Protocol")?;
111
112                if self.options.proxy_protocol_mode == ProxyProtocolMode::Forced && hdr.is_none() {
113                    return Err(Error::NoProxyProtocolDetected);
114                }
115
116                (Box::new(stream), hdr)
117            } else {
118                (Box::new(stream), None)
119            };
120
121        // Use IPs from Proxy Protocol if available
122        let (local_addr, remote_addr) = proxy_hdr
123            .map(|x| (Addr::Tcp(x.dst), Addr::Tcp(x.src)))
124            .unwrap_or_else(|| (self.addr.clone(), self.remote_addr.clone()));
125
126        let conn_info = Arc::new(ConnInfo {
127            id: Uuid::now_v7(),
128            accepted_at,
129            remote_addr,
130            local_addr,
131            traffic: stats.clone(),
132            req_count: AtomicU64::new(0),
133            close: self.token_forceful.clone(),
134        });
135
136        // Perform TLS handshake if we're in TLS mode
137        let (stream, tls_info): (Box<dyn AsyncReadWrite>, _) = if let Some(rustls_cfg) =
138            &self.rustls_cfg
139        {
140            debug!("{}: performing TLS handshake", self);
141
142            let (mut stream_tls, tls_info) = timeout(
143                self.options.tls_handshake_timeout,
144                tls_handshake(rustls_cfg.clone(), stream),
145            )
146            .await
147            .context("TLS handshake timed out")?
148            .context("TLS handshake failed")?;
149
150            debug!(
151                "{}: handshake finished in {}ms (SNI: {:?}, proto: {:?}, cipher: {:?}, ALPN: {:?})",
152                self,
153                tls_info.handshake_dur.as_millis(),
154                tls_info.sni,
155                tls_info.protocol,
156                tls_info.cipher,
157                tls_info.alpn,
158            );
159
160            // Close the connection if agreed ALPN is ACME - the handshake is enough for the challenge
161            if tls_info
162                .alpn
163                .as_ref()
164                .is_some_and(|x| x.as_bytes() == ALPN_ACME)
165            {
166                debug!("{self}: ACME ALPN - closing connection");
167
168                timeout(Duration::from_secs(5), stream_tls.shutdown())
169                    .await
170                    .context("socket shutdown timed out")?
171                    .context("socket shutdown failed")?;
172
173                return Ok(());
174            }
175
176            (Box::new(stream_tls), Some(Arc::new(tls_info)))
177        } else {
178            (Box::new(stream), None)
179        };
180
181        // Record TLS metrics
182        if let Some(v) = &tls_info {
183            labels[2] = v.protocol.as_str().unwrap();
184            labels[3] = v.cipher.as_str().unwrap();
185
186            self.metrics
187                .conn_tls_handshake_duration
188                .with_label_values(&labels[0..4])
189                .observe(v.handshake_dur.as_secs_f64());
190        }
191
192        self.metrics
193            .conns_open
194            .with_label_values(&labels[0..4])
195            .inc();
196
197        let requests_inflight = self
198            .metrics
199            .requests_inflight
200            .with_label_values(&labels[0..4]);
201
202        // Handle the connection
203        let result = self
204            .handle_inner(stream, conn_info.clone(), tls_info, requests_inflight)
205            .await;
206
207        // Record connection metrics
208        let (sent, rcvd) = (stats.sent(), stats.rcvd());
209        let dur = accepted_at.elapsed().as_secs_f64();
210        let reqs = conn_info.req_count.load(Ordering::SeqCst);
211
212        // force-closed
213        if self.token_forceful.is_cancelled() {
214            labels[4] = "yes";
215        }
216        // recycled
217        if self.token_graceful.is_cancelled() {
218            labels[5] = "yes";
219        }
220
221        self.metrics.conns.with_label_values(labels).inc();
222        self.metrics
223            .conns_open
224            .with_label_values(&labels[0..4])
225            .dec();
226        self.metrics.requests.with_label_values(labels).inc_by(reqs);
227        self.metrics
228            .bytes_rcvd
229            .with_label_values(labels)
230            .inc_by(rcvd);
231        self.metrics
232            .bytes_sent
233            .with_label_values(labels)
234            .inc_by(sent);
235        self.metrics
236            .conn_duration
237            .with_label_values(labels)
238            .observe(dur);
239        self.metrics
240            .requests_per_conn
241            .with_label_values(labels)
242            .observe(reqs as f64);
243
244        debug!(
245            "{self}: connection closed (rcvd: {rcvd}, sent: {sent}, reqs: {reqs}, duration: {dur}, graceful: {}, forced close: {})",
246            self.token_graceful.is_cancelled(),
247            self.token_forceful.is_cancelled(),
248        );
249
250        result
251    }
252
253    async fn handle_inner(
254        &self,
255        stream: Box<dyn AsyncReadWrite>,
256        conn_info: Arc<ConnInfo>,
257        tls_info: Option<Arc<TlsInfo>>,
258        requests_inflight: GenericGauge<AtomicI64>,
259    ) -> Result<(), Error> {
260        // Create a timer for idle connection tracking.
261        // Falls back to 10 years if idle timer is not set (for simplicity)
262        let mut idle_timer = Box::pin(sleep(self.options.idle_timeout.unwrap_or(10 * YEAR)));
263
264        // Create channel to notify about request start/stop.
265        // Use bounded but big enough so that it's larger than our concurrency.
266        let (state_tx, mut state_rx) = channel(65536);
267
268        // Apply timeouts on read/write calls
269        let mut stream = TimeoutStream::new(stream);
270        stream.set_read_timeout(self.options.read_timeout);
271        stream.set_write_timeout(self.options.write_timeout);
272
273        // Convert stream from Tokio to Hyper
274        let stream = TokioIo::new(stream);
275
276        // Convert router to Hyper service
277        let max_requests_per_conn = self.options.max_requests_per_conn;
278        let service = hyper::service::service_fn(move |mut request: Request<Incoming>| {
279            // Notify that we have started processing the request
280            let _ = state_tx.try_send(RequestState::Start);
281
282            // Inject connection information
283            request.extensions_mut().insert(conn_info.clone());
284            if let Some(v) = &tls_info {
285                request.extensions_mut().insert(v.clone());
286            }
287
288            // Clone the stuff needed in the async block below
289            let mut router = self.router.clone();
290            let token = self.token_graceful.clone();
291            let conn_info = conn_info.clone();
292            let state_tx = state_tx.clone();
293            let requests_inflight = requests_inflight.clone();
294
295            // Return the future
296            async move {
297                // Increase the global inflight requests counter
298                requests_inflight.inc();
299
300                // Since the future can be cancelled we need defer to decrease the counter in any case
301                // to avoid leaking the inflight requests
302                defer! {
303                    requests_inflight.dec();
304                }
305
306                // Execute the request
307                let result = router.call(request).await.map(|x| {
308                    // Wrap the response body into a notifying one
309                    let (parts, body) = x.into_parts();
310                    let body = NotifyingBody::new(body, state_tx, RequestState::End);
311                    Response::from_parts(parts, body)
312                });
313
314                // Check if we need to gracefully shutdown this connection
315                if let Some(v) = max_requests_per_conn {
316                    let req_count = conn_info.req_count.fetch_add(1, Ordering::SeqCst);
317                    if req_count + 1 >= v {
318                        token.cancel();
319                    }
320                }
321
322                result
323            }
324        });
325
326        // Serve the connection
327        let conn = self
328            .builder
329            .serve_connection_with_upgrades(Box::pin(stream), service);
330
331        // Using mutable future reference requires pinning
332        pin!(conn);
333
334        loop {
335            select! {
336                biased; // Poll top-down
337
338                // Immediately close the connection if was requested
339                () = self.token_forceful.cancelled() => {
340                    break;
341                }
342
343                // Start graceful shutdown of the connection
344                () = self.token_graceful.cancelled() => {
345                    // For H2: sends GOAWAY frames to the client
346                    // For H1: disables keepalives
347                    conn.as_mut().graceful_shutdown();
348
349                    // Wait for the grace period to finish or connection to complete.
350                    // Connection must still be polled for the shutdown to proceed.
351                    // We don't really care for the result.
352                    let _ = timeout(self.options.grace_period, conn.as_mut()).await;
353                    break;
354                },
355
356                // Get request state change notifications
357                Some(v) = state_rx.recv() => {
358                    match v {
359                        RequestState::Start => {
360                            let reqs = self.requests.fetch_add(1, Ordering::SeqCst) + 1;
361                            debug!("{self}: request started");
362
363                            // Effectively disable the timer by setting it to 10 years into the future.
364                            // TODO improve?
365                            if self.options.idle_timeout.is_some() {
366                                debug!("{self}: stopping idle timer (now: {reqs})");
367                                idle_timer.as_mut().reset(tokio::time::Instant::now() + 10 * YEAR);
368                            }
369                        },
370
371                        RequestState::End => {
372                            let reqs = self.requests.fetch_sub(1, Ordering::SeqCst) - 1;
373                            debug!("{self}: request finished (now: {reqs})");
374
375                            // Check if the number of outstanding requests is now zero
376                            if let Some(v) = self.options.idle_timeout && reqs == 0 {
377                                // Enable the idle timer
378                                debug!("{self}: no outstanding requests, starting timer");
379                                idle_timer.as_mut().reset(tokio::time::Instant::now() + v);
380                            }
381                        }
382                    }
383                },
384
385                // See if the idle timeout has kicked in
386                () = idle_timer.as_mut(), if self.options.idle_timeout.is_some() => {
387                    debug!("{self}: Idle timeout triggered, closing");
388
389                    // Signal that we're closing
390                    conn.as_mut().graceful_shutdown();
391                    // Give the client some time to shut down
392                    let _ = timeout(Duration::from_secs(5), conn.as_mut()).await;
393                    break;
394                },
395
396                // Drive the connection by polling it
397                v = conn.as_mut() => {
398                    if let Err(e) = v {
399                        return Err(anyhow!("unable to serve connection: {e:#}").into());
400                    }
401
402                    break;
403                },
404            }
405        }
406
407        Ok(())
408    }
409}
410
/// Builder for a `Server`
pub struct ServerBuilder {
    /// Address to listen on; `build()` fails if it was never set
    addr: Option<Addr>,
    /// Router that the server will serve
    router: Router,
    /// Registry used to create `Metrics` when none were provided explicitly,
    /// also passed to `prepare_server_config` by `with_rustls_single_cert`
    registry: Registry,
    /// Explicitly provided metrics; take precedence over `registry` in `build()`
    metrics: Option<Metrics>,
    /// Server options
    options: ServerOptions,
    /// TLS configuration; `None` means no TLS
    rustls_cfg: Option<rustls::ServerConfig>,
}
420
421impl ServerBuilder {
422    /// Creates a builder with a given router & defaults
423    pub fn new(router: Router) -> Self {
424        Self {
425            addr: None,
426            router,
427            registry: Registry::new(),
428            metrics: None,
429            options: ServerOptions::default(),
430            rustls_cfg: None,
431        }
432    }
433
434    /// Listens on the given TCP socket
435    pub fn listen_tcp(mut self, socket: SocketAddr) -> Self {
436        self.addr = Some(Addr::Tcp(socket));
437        self
438    }
439
440    /// Listens on the given Unix socket
441    pub fn listen_unix(mut self, path: PathBuf) -> Self {
442        self.addr = Some(Addr::Unix(path));
443        self
444    }
445
446    /// Sets up metrics with provided Registry
447    pub fn with_metrics_registry(mut self, registry: &Registry) -> Self {
448        self.registry = registry.clone();
449        self
450    }
451
452    /// Sets up metrics with provided Metrics.
453    /// Overrides `with_metrics_registry()`.
454    pub fn with_metrics(mut self, metrics: Metrics) -> Self {
455        self.metrics = Some(metrics);
456        self
457    }
458
459    /// Sets up TLS with provided ServerConfig
460    pub fn with_rustls_config(mut self, rustls_cfg: rustls::ServerConfig) -> Self {
461        self.rustls_cfg = Some(rustls_cfg);
462        self
463    }
464
465    /// Sets up with provided Options
466    pub const fn with_options(mut self, options: ServerOptions) -> Self {
467        self.options = options;
468        self
469    }
470
471    /// Sets up TLS with a single certificate.
472    /// If metrics are needed - provide registry using `with_metrics_registry` before calling this method.
473    pub fn with_rustls_single_cert(mut self, cert: PathBuf, key: PathBuf) -> Result<Self, Error> {
474        let cert = std::fs::read(cert).context("unable to read cert")?;
475        let key = std::fs::read(key).context("unable to read key")?;
476        let cert = pem_convert_to_rustls(&key, &cert).context("unable to parse cert+key pair")?;
477        let resolver = SingleCertAndKey::from(cert);
478        let tls_opts = TlsOptions::default();
479        let rustls_cfg = prepare_server_config(tls_opts, Arc::new(resolver), &self.registry);
480
481        self.rustls_cfg = Some(rustls_cfg);
482        Ok(self)
483    }
484
485    /// Build the Server
486    pub fn build(self) -> Result<Server, Error> {
487        let Some(addr) = self.addr else {
488            return Err(Error::Generic(anyhow!("Listening address not specified")));
489        };
490
491        let metrics = self.metrics.unwrap_or_else(|| Metrics::new(&self.registry));
492
493        Ok(Server::new(
494            addr,
495            self.router,
496            self.options,
497            metrics,
498            self.rustls_cfg,
499        ))
500    }
501}
502
/// Listens for new connections with an optional TLS and serves provided Router
pub struct Server {
    /// Address the server listens on
    addr: Addr,
    /// Router serving the requests; cloned into every connection
    router: Router,
    /// Tracks the spawned connection tasks so that shutdown can wait for them
    tracker: TaskTracker,
    /// Server options
    options: ServerOptions,
    /// Metrics shared with every connection
    metrics: Metrics,
    /// Pre-configured Hyper connection builder; cloned into every connection
    builder: Builder<TokioExecutor>,
    /// TLS configuration; `None` means connections are served without TLS
    rustls_cfg: Option<Arc<rustls::ServerConfig>>,
}
513
514impl Display for Server {
515    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516        write!(f, "[{}]", self.addr)
517    }
518}
519
520impl Server {
521    /// Create a new `Server`
522    pub fn new(
523        addr: Addr,
524        router: Router,
525        options: ServerOptions,
526        metrics: Metrics,
527        rustls_cfg: Option<rustls::ServerConfig>,
528    ) -> Self {
529        // Prepare Hyper connection builder
530        // It automatically figures out whether to do HTTP1 or HTTP2
531        let mut builder = Builder::new(TokioExecutor::new());
532        builder
533            .http1()
534            .timer(TokioTimer::new()) // Needed for the keepalives below
535            .header_read_timeout(Some(options.http1_header_read_timeout))
536            .keep_alive(true)
537            .http2()
538            .adaptive_window(true)
539            .max_concurrent_streams(Some(options.http2_max_streams))
540            .timer(TokioTimer::new()) // Needed for the keepalives below
541            .keep_alive_interval(options.http2_keepalive_interval)
542            .keep_alive_timeout(options.http2_keepalive_timeout)
543            .enable_connect_protocol(); // Needed for Websockets
544
545        Self {
546            addr,
547            router,
548            options,
549            metrics,
550            tracker: TaskTracker::new(),
551            builder,
552            rustls_cfg: rustls_cfg.map(Arc::new),
553        }
554    }
555
556    /// Start serving with given cancellation token
557    pub async fn serve(&self, token: CancellationToken) -> Result<(), Error> {
558        let opts = ListenerOpts {
559            backlog: self.options.backlog,
560            mss: self.options.tcp_mss,
561            keepalive: (&self.options).into(),
562        };
563
564        let listener =
565            Listener::new(self.addr.clone(), opts).context("unable to create listener")?;
566        self.serve_with_listener(listener, token).await
567    }
568
569    fn spawn_connection(
570        &self,
571        stream: Box<dyn AsyncReadWrite>,
572        remote_addr: Addr,
573        token: CancellationToken,
574    ) {
575        // Create a new connection
576        // Router & TlsAcceptor are both Arc<> inside so it's cheap to clone
577        // Builder is a bit more complex, but cloning is better than to create it again
578        let conn = Conn {
579            addr: self.addr.clone(),
580            remote_addr: remote_addr.clone(),
581            router: self.router.clone(),
582            builder: self.builder.clone(),
583            token_graceful: token,
584            token_forceful: CancellationToken::new(),
585            options: self.options,
586            metrics: self.metrics.clone(), // All metrics have Arc inside
587            requests: AtomicU32::new(0),
588            rustls_cfg: self.rustls_cfg.clone(),
589        };
590
591        // Spawn a task to handle connection & track it
592        self.tracker.spawn(async move {
593            if let Err(e) = conn.handle(stream).await {
594                info!(
595                    "[{}] <- [{remote_addr}]: failed to handle connection: {e:#}",
596                    conn.addr
597                );
598            }
599
600            debug!("[{}] <- [{remote_addr}]: connection finished", conn.addr);
601        });
602    }
603
604    /// Start serving with a given listener & cancellation token
605    pub async fn serve_with_listener(
606        &self,
607        listener: Listener,
608        token: CancellationToken,
609    ) -> Result<(), Error> {
610        warn!("{self}: running (TLS: {})", self.rustls_cfg.is_some());
611
612        loop {
613            select! {
614                biased; // Poll top-down
615
616                () = token.cancelled() => {
617                    // Stop accepting new connections
618                    drop(listener);
619
620                    warn!("{self}: shutting down, waiting for the active connections to close for {}s", self.options.grace_period.as_secs());
621                    self.tracker.close();
622
623                    select! {
624                        () = sleep(self.options.grace_period + Duration::from_secs(5)) => {
625                            warn!("{self}: connections didn't close in time, shutting down anyway");
626                        },
627                        () = self.tracker.wait() => {},
628                    }
629
630                    warn!("{self}: shut down");
631
632                    // Remove the socket
633                    if let Addr::Unix(v) = &self.addr {
634                        let _ = std::fs::remove_file(v);
635                    }
636
637                    return Ok(());
638                },
639
640                // Try to accept the connection
641                v = listener.accept() => {
642                    let (stream, remote_addr) = match v {
643                        Ok(v) => v,
644                        Err(e) => {
645                            warn!("{self}: unable to accept connection: {e:#}");
646                            // Wait few ms just in case that there's an overflown backlog
647                            // so that we don't run into infinite error loop
648                            sleep(Duration::from_millis(10)).await;
649                            continue;
650                        }
651                    };
652
653                    self.spawn_connection(stream, remote_addr, token.child_token());
654                }
655            }
656        }
657    }
658}
659
660#[async_trait]
661impl Run for Server {
662    async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> {
663        self.serve(token).await?;
664        Ok(())
665    }
666}
667
#[cfg(test)]
mod test {
    use http::StatusCode;

    use crate::network::listener::listen_tcp;

    use super::*;

    /// Starts a server with an empty router on an ephemeral port and checks
    /// that a plain HTTP request gets a 404 back.
    #[tokio::test]
    async fn test_server() {
        let opts = ServerOptions::default();
        let listener = listen_tcp(
            "127.0.0.1:0".parse().unwrap(),
            ListenerOpts {
                backlog: 128,
                mss: None,
                keepalive: (&opts).into(),
            },
        )
        .unwrap();

        let addr = listener.local_addr().unwrap();

        let server = Server::new(
            Addr::Tcp(addr),
            Router::new(),
            opts,
            Metrics::new(&Registry::new()),
            None,
        );

        tokio::spawn(async move {
            server
                .serve_with_listener(listener.into(), CancellationToken::new())
                .await
                .unwrap();
        });

        // Retry a few times since the server task might not be running yet
        let mut status = None;
        for _ in 0..10 {
            match reqwest::get(format!("http://{addr}")).await {
                Ok(v) => {
                    status = Some(v.status());
                    break;
                }
                Err(_) => tokio::time::sleep(Duration::from_millis(10)).await,
            }
        }

        // Fail if no attempt succeeded instead of passing without checking anything
        assert_eq!(
            status,
            Some(StatusCode::NOT_FOUND),
            "server didn't answer with 404 within 10 attempts"
        );
    }
}