tower_server/
lib.rs

1//! High-level hyper server interfacing with tower-service.
2//!
3//! ## Features:
4//! * `rustls` integration
5//! * Graceful shutdown using `CancellationToken` from `tokio_util`.
//! * Optional connection middleware for handling the remote address
7//! * Dynamic TLS reconfiguration without restarting server, for e.g. certificate rotation
8//! * Optional TLS connection middleware, for example for mTLS integration
9//!
10//! ## Example usage using Axum with graceful shutdown:
11//!
12//! ```rust
13//! # async fn serve() {
14//! #[cfg(feature = "signal")]
15//! // Uses the built-in termination signal:
16//! let shutdown_token = tower_server::signal::termination_signal();
17//!
18//! #[cfg(not(feature = "signal"))]
19//! // Configure the shutdown token manually:
20//! let shutdown_token = tokio_util::sync::CancellationToken::default();
21//!
22//! let server = tower_server::Builder::new("0.0.0.0:8080".parse().unwrap())
23//!     .with_graceful_shutdown(shutdown_token)
24//!     .bind()
25//!     .await
26//!     .unwrap();
27//!
28//! server.serve(axum::Router::new()).await;
29//! # }
30//! ```
31//!
32//! ## Example using connection middleware
33//!
34//! ```rust
35//! #[derive(Clone)]
36//! struct RemoteAddr(std::net::SocketAddr);
37//!
38//! # async fn serve() {
39//! let server = tower_server::Builder::new("0.0.0.0:8080".parse().unwrap())
40//!     .with_connection_middleware(|req, remote_addr| {
41//!         req.extensions_mut().insert(RemoteAddr(remote_addr));
42//!     })
43//!     .bind()
44//!     .await
45//!     .unwrap();
46//!
47//! server.serve(axum::Router::new()).await;
48//! # }
49//! ```
50//!
51//! ## Example using TLS connection middleware
52//!
53//! ```rust
54//! # use std::sync::Arc;
55//! use rustls_pki_types::CertificateDer;
56//! use hyper::body::Incoming;
57//!
58//! #[derive(Clone)]
59//! struct PeerCertMiddleware;
60//!
61//! /// A request extension that includes the mTLS peer certificate
62//! #[derive(Clone)]
63//! struct PeerCertificate(CertificateDer<'static>);
64//!
65//! impl tower_server::tls::TlsConnectionMiddleware for PeerCertMiddleware {
66//!     type Data = Option<PeerCertificate>;
67//!
68//!     /// Step 1: Extract data from the rustls server connection.
//!     /// At this stage of the TLS handshake, the http::Request doesn't exist yet.
70//!     fn data(&self, connection: &rustls::ServerConnection) -> Self::Data {
71//!         Some(PeerCertificate(connection.peer_certificates()?.first()?.clone()))
72//!     }
73//!
74//!     /// Step 2: The http::Request now exists, and the request extension can be injected.
75//!     fn call(&self, req: &mut http::Request<Incoming>, data: &Option<PeerCertificate>) {
76//!         if let Some(peer_certificate) = data {
77//!             req.extensions_mut().insert(peer_certificate.clone());
78//!         }
79//!     }
80//! }
81//!
82//! # async fn serve() {
83//! let server = tower_server::Builder::new("0.0.0.0:443".parse().unwrap())
84//!     .with_scheme(tower_server::Scheme::Https)
85//!     .with_tls_connection_middleware(PeerCertMiddleware)
86//!     .with_tls_config(
87//!         rustls::server::ServerConfig::builder()
88//!             // Instead of this, actually configure client authentication here:
89//!             .with_no_client_auth()
90//!             // just a compiling example for setting a cert resolver, replace this with your actual config:
91//!             .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new()))
92//!     )
93//!     .bind()
94//!     .await
95//!     .unwrap();
96//!
97//! server.serve(axum::Router::new()).await;
98//! # }
99//! ```
100//!
//! ## Example using dynamically changing TLS configuration
102//! [tls::TlsConfigurer] is implemented for [futures_util::stream::BoxStream] of [Arc]ed [rustls::server::ServerConfig]s:
103//!
104//! ```rust
105//! # use std::sync::Arc;
106//! # use std::time::Duration;
107//! use futures_util::StreamExt;
108//!
109//! # async fn serve() {
110//! let initial_tls_config = Arc::new(
111//!     rustls::server::ServerConfig::builder()
112//!         .with_no_client_auth()
113//!         .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new()))
114//! );
115//!
116//! let tls_config_rotation = futures_util::stream::unfold((), |_| async move {
117//!     // renews after a fixed delay:
118//!     tokio::time::sleep(Duration::from_secs(10)).await;
119//!
120//!     // just for illustration purposes, replace with your own ServerConfig:
121//!     let renewed_config = Arc::new(
122//!         rustls::server::ServerConfig::builder()
123//!             .with_no_client_auth()
124//!             .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new()))
125//!     );
126//!
127//!     Some((renewed_config, ()))
128//! });
129//!
130//! let server = tower_server::Builder::new("0.0.0.0:443".parse().unwrap())
131//!     .with_scheme(tower_server::Scheme::Https)
132//!     .with_tls_config(
133//!         // takes the initial config, which resolves without delay,
134//!         // chained together with the subsequent dynamic updates:
135//!         futures_util::stream::iter([initial_tls_config])
136//!             .chain(tls_config_rotation)
137//!             .boxed()
138//!     )
139//!     .bind()
140//!     .await
141//!     .unwrap();
142//!
143//! server.serve(axum::Router::new()).await;
144//! # }
145//! ```
146
147#![forbid(unsafe_code)]
148#![warn(missing_docs)]
149#![cfg_attr(feature = "unstable", feature(doc_auto_cfg))]
150
151use std::net::SocketAddr;
152use std::{error::Error as StdError, sync::Arc};
153
154use arc_swap::ArcSwap;
155use futures_util::future::poll_fn;
156use futures_util::stream::BoxStream;
157use futures_util::StreamExt;
158use hyper::body::Incoming;
159use hyper_util::rt::{TokioExecutor, TokioIo};
160use pin_utils::pin_mut;
161use rustls::ServerConfig;
162use tls::{NoOpTlsConnectionMiddleware, TlsConfigurer, TlsConnectionMiddleware};
163use tokio::net::TcpListener;
164use tokio::sync::watch;
165use tokio_rustls::TlsAcceptor;
166use tokio_util::sync::CancellationToken;
167use tracing::{info, trace};
168
169pub mod tls;
170
171#[cfg(feature = "signal")]
172pub mod signal;
173
174/// Server configuration.
175pub struct Builder<TlsM> {
176    addr: SocketAddr,
177    scheme: Scheme,
178    cancel: CancellationToken,
179    connection_middleware: fn(&mut http::Request<Incoming>, SocketAddr),
180    tls_connection_middleware: TlsM,
181    tls_config_is_dynamic: bool,
182    tls_config_stream: BoxStream<'static, Arc<rustls::server::ServerConfig>>,
183}
184
185impl Builder<NoOpTlsConnectionMiddleware> {
186    /// Configure using a socket addr using the Http scheme.
187    pub fn new(addr: SocketAddr) -> Self {
188        Self {
189            addr,
190            scheme: Scheme::Http,
191            cancel: Default::default(),
192            connection_middleware: |_, _| {},
193            tls_connection_middleware: NoOpTlsConnectionMiddleware,
194            tls_config_is_dynamic: false,
195            tls_config_stream: futures_util::stream::empty().boxed(),
196        }
197    }
198
199    /// Configure the tower server from a server Url with auto-configuration of http scheme.
200    #[cfg(feature = "url")]
201    pub fn from_url(base_url: url::Url) -> anyhow::Result<Self> {
202        use anyhow::anyhow;
203
204        let port = base_url
205            .port_or_known_default()
206            .ok_or_else(|| anyhow!("server port not deducible from base url"))?;
207        let addr: SocketAddr = match base_url.host() {
208            // treat domain name as binding on every interface
209            Some(url::Host::Domain(_)) => ([0, 0, 0, 0], port).into(),
210            Some(url::Host::Ipv4(v4)) => (v4, port).into(),
211            Some(url::Host::Ipv6(v6)) => (v6, port).into(),
212            None => return Err(anyhow!("no host in url")),
213        };
214
215        Ok(Self {
216            addr,
217            cancel: Default::default(),
218            connection_middleware: |_, _| {},
219            tls_connection_middleware: NoOpTlsConnectionMiddleware,
220            scheme: match base_url.scheme() {
221                "http" => Scheme::Http,
222                "https" => Scheme::Https,
223                scheme => return Err(anyhow!("unknown http server scheme: {scheme}")),
224            },
225            tls_config_is_dynamic: false,
226            tls_config_stream: futures_util::stream::empty().boxed(),
227        })
228    }
229}
230
231impl<TlsM> Builder<TlsM> {
232    /// Set the scheme used by the the server. A Https scheme requires a TLS config factory.
233    pub fn with_scheme(mut self, scheme: Scheme) -> Self {
234        self.scheme = scheme;
235        self
236    }
237
238    /// Register a function that acts a connection middleware on any accepted connection.
239    /// The middleware is able to modify every incoming request.
240    pub fn with_connection_middleware(mut self, middleware: ConnectionMiddleware) -> Self {
241        self.connection_middleware = middleware;
242        self
243    }
244
245    /// Register a TLS configurator.
246    /// TLS configuration will only be invoked when Scheme is set to Https.
247    pub fn with_tls_config(mut self, tls: impl TlsConfigurer) -> Self {
248        self.tls_config_is_dynamic = tls.is_dynamic();
249        self.tls_config_stream = tls.into_stream();
250        self
251    }
252
253    /// Register a TLS connection middleware.
254    pub fn with_tls_connection_middleware<T: TlsConnectionMiddleware>(
255        self,
256        middleware: T,
257    ) -> Builder<T> {
258        Builder {
259            addr: self.addr,
260            connection_middleware: self.connection_middleware,
261            tls_connection_middleware: middleware,
262            scheme: self.scheme,
263            cancel: self.cancel,
264            tls_config_is_dynamic: self.tls_config_is_dynamic,
265            tls_config_stream: self.tls_config_stream,
266        }
267    }
268
269    /// Register a cancellation token that enables graceful shutdown.
270    pub fn with_graceful_shutdown(mut self, cancel: CancellationToken) -> Self {
271        self.cancel = cancel;
272        self
273    }
274
275    /// Build server and bind it to the configured address.
276    pub async fn bind(self) -> anyhow::Result<TowerServer<TlsM>> {
277        let mut tls_config_stream = self.tls_config_stream;
278
279        let tls_config_swap = match self.scheme {
280            Scheme::Http => None,
281            Scheme::Https => {
282                let initial_tls_config = tls_config_stream.next().await.unwrap_or_else(|| {
283                    panic!("Https scheme detected, but no TLS config registered")
284                });
285
286                let swap = Arc::new(ArcSwap::new(initial_tls_config));
287
288                // set up subscription for dynamically changing TLS config
289                if self.tls_config_is_dynamic {
290                    let cancel = self.cancel.clone();
291                    let swap = swap.clone();
292
293                    tokio::spawn(async move {
294                        loop {
295                            tokio::select! {
296                                next_tls_config = tls_config_stream.next() => {
297                                    if let Some(tls_config) = next_tls_config {
298                                        tracing::info!("renewing TLS ServerConfig");
299                                        swap.store(tls_config);
300                                    } else {
301                                        return;
302                                    }
303                                }
304                                _ = cancel.cancelled() => {
305                                    return;
306                                }
307                            }
308                        }
309                    });
310                }
311
312                Some(swap)
313            }
314        };
315
316        let listener = TcpListener::bind(self.addr).await?;
317
318        Ok(TowerServer {
319            listener,
320            tls_config_swap,
321            cancel: self.cancel,
322            connection_middleware: self.connection_middleware,
323            tls_connection_middleware: self.tls_connection_middleware,
324        })
325    }
326}
327
328/// Desired HTTP scheme.
329#[derive(Clone, Copy)]
330pub enum Scheme {
331    /// HTTP without TLS.
332    Http,
333    /// HTTP with TLS.
334    Https,
335}
336
337/// The type of the connection middleware.
338///
339/// It is a function which receives a mutable request and a [SocketAddr] representing the remote client.
340pub type ConnectionMiddleware = fn(&mut http::Request<Incoming>, SocketAddr);
341
342/// A bound server, ready for running accept-loop using a tower service.
343pub struct TowerServer<TlsM = NoOpTlsConnectionMiddleware> {
344    listener: TcpListener,
345    tls_config_swap: Option<Arc<ArcSwap<ServerConfig>>>,
346    cancel: CancellationToken,
347    connection_middleware: fn(&mut http::Request<Incoming>, SocketAddr),
348    tls_connection_middleware: TlsM,
349}
350
351impl<TlsM> TowerServer<TlsM> {
352    /// Access the locally bound address
353    pub fn local_addr(&self) -> anyhow::Result<SocketAddr> {
354        self.listener.local_addr().map_err(|e| e.into())
355    }
356
357    /// Run HTTP accept loop, handling every request using the passwed tower service.
358    pub async fn serve<S, B>(self, tower_service: S)
359    where
360        S: tower_service::Service<
361                http::Request<hyper::body::Incoming>,
362                Response = http::Response<B>,
363            >
364            + Send
365            + Sync
366            + 'static
367            + Clone,
368        S::Future: 'static + Send,
369        S::Error: Into<Box<dyn StdError + Send + Sync + 'static>>,
370        B: http_body::Body + Send + 'static,
371        B::Data: Send,
372        B::Error: Into<Box<dyn StdError + Send + Sync + 'static>>,
373        TlsM: TlsConnectionMiddleware,
374    {
375        // tracks how long to gracefully await shutdown.
376        // Nothing is ever sent on this channel, it's only used for
377        // tracking the number of live receivers.
378        // each active connection has a clone of `close_rx`,
379        // at the end of the function `close_tx.closed()` is awaited,
380        // which finishes when no receivers are available.
381        let (close_tx, close_rx) = watch::channel(());
382
383        // accept loop
384        loop {
385            let (tcp_stream, remote_addr) = tokio::select! {
386                accept = self.listener.accept() => {
387                    match accept {
388                        Ok(stream_addr) => stream_addr,
389                        Err(_) => {
390                            continue;
391                        }
392                    }
393                }
394                _ = self.cancel.cancelled() => {
395                    trace!("signal received, not accepting new connections");
396                    break;
397                }
398            };
399
400            let tls_config_swap = self.tls_config_swap.clone();
401            let close_rx = close_rx.clone();
402            let cancel = self.cancel.clone();
403            let connection_middleware = self.connection_middleware;
404            let tls_connection_middleware = self.tls_connection_middleware.clone();
405            let tower_service = tower_service.clone();
406
407            tokio::spawn(async move {
408                let connection_builder =
409                    hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
410                match tls_config_swap {
411                    None => {
412                        let connection = connection_builder.serve_connection_with_upgrades(
413                            TokioIo::new(tcp_stream),
414                            hyper::service::service_fn(move |mut req| {
415                                connection_middleware(&mut req, remote_addr);
416                                let mut tower_service = tower_service.clone();
417
418                                async move {
419                                    poll_fn(|cx| tower_service.poll_ready(cx)).await?;
420                                    tower_service.call(req).await
421                                }
422                            }),
423                        );
424                        pin_mut!(connection);
425                        tokio::select! {
426                            biased;
427                            _ = connection.as_mut() => {}
428                            _ = cancel.cancelled() => {
429                                connection.as_mut().graceful_shutdown();
430                                let _ = connection.as_mut().await;
431                            }
432                        }
433                    }
434                    Some(tls_config_swap) => {
435                        let tls_acceptor = TlsAcceptor::from(tls_config_swap.load_full());
436                        let tls_stream = match tls_acceptor.accept(tcp_stream).await {
437                            Ok(tls_stream) => tls_stream,
438                            Err(err) => {
439                                info!(?err, "failed to perform tls handshake");
440                                return;
441                            }
442                        };
443
444                        let tls_middleware_data =
445                            tls_connection_middleware.data(tls_stream.get_ref().1);
446
447                        let connection = connection_builder.serve_connection_with_upgrades(
448                            TokioIo::new(tls_stream),
449                            hyper::service::service_fn(move |mut req| {
450                                connection_middleware(&mut req, remote_addr);
451                                tls_connection_middleware.call(&mut req, &tls_middleware_data);
452                                let mut tower_service = tower_service.clone();
453
454                                async move {
455                                    poll_fn(|cx| tower_service.poll_ready(cx)).await?;
456                                    tower_service.call(req).await
457                                }
458                            }),
459                        );
460
461                        pin_mut!(connection);
462                        tokio::select! {
463                            biased;
464                            _ = connection.as_mut() => {}
465                            _ = cancel.cancelled() => {
466                                connection.as_mut().graceful_shutdown();
467                                let _ = connection.as_mut().await;
468                            }
469                        }
470                    }
471                }
472
473                drop(close_rx);
474            });
475        }
476
477        drop(close_rx);
478        trace!(
479            "waiting for {} task(s) to finish",
480            close_tx.receiver_count()
481        );
482        close_tx.closed().await;
483    }
484}