tower_server/lib.rs
1//! High-level hyper server interfacing with tower-service.
2//!
3//! ## Features:
4//! * `rustls` integration
5//! * Graceful shutdown using `CancellationToken` from `tokio_util`.
//! * Optional connection middleware for handling the remote address
7//! * Dynamic TLS reconfiguration without restarting server, for e.g. certificate rotation
8//! * Optional TLS connection middleware, for example for mTLS integration
9//!
10//! ## Example usage using Axum with graceful shutdown:
11//!
12//! ```rust
13//! # async fn serve() {
14//! #[cfg(feature = "signal")]
15//! // Uses the built-in termination signal:
16//! let shutdown_token = tower_server::signal::termination_signal();
17//!
18//! #[cfg(not(feature = "signal"))]
19//! // Configure the shutdown token manually:
20//! let shutdown_token = tokio_util::sync::CancellationToken::default();
21//!
22//! let server = tower_server::Builder::new("0.0.0.0:8080".parse().unwrap())
23//! .with_graceful_shutdown(shutdown_token)
24//! .bind()
25//! .await
26//! .unwrap();
27//!
28//! server.serve(axum::Router::new()).await;
29//! # }
30//! ```
31//!
32//! ## Example using connection middleware
33//!
34//! ```rust
35//! #[derive(Clone)]
36//! struct RemoteAddr(std::net::SocketAddr);
37//!
38//! # async fn serve() {
39//! let server = tower_server::Builder::new("0.0.0.0:8080".parse().unwrap())
40//! .with_connection_middleware(|req, remote_addr| {
41//! req.extensions_mut().insert(RemoteAddr(remote_addr));
42//! })
43//! .bind()
44//! .await
45//! .unwrap();
46//!
47//! server.serve(axum::Router::new()).await;
48//! # }
49//! ```
50//!
51//! ## Example using TLS connection middleware
52//!
53//! ```rust
54//! # use std::sync::Arc;
55//! use rustls_pki_types::CertificateDer;
56//! use hyper::body::Incoming;
57//!
58//! #[derive(Clone)]
59//! struct PeerCertMiddleware;
60//!
61//! /// A request extension that includes the mTLS peer certificate
62//! #[derive(Clone)]
63//! struct PeerCertificate(CertificateDer<'static>);
64//!
65//! impl tower_server::tls::TlsConnectionMiddleware for PeerCertMiddleware {
66//! type Data = Option<PeerCertificate>;
67//!
68//! /// Step 1: Extract data from the rustls server connection.
69//! /// At this stage of TLS handshake the http::Request doesn't yet exist.
70//! fn data(&self, connection: &rustls::ServerConnection) -> Self::Data {
71//! Some(PeerCertificate(connection.peer_certificates()?.first()?.clone()))
72//! }
73//!
74//! /// Step 2: The http::Request now exists, and the request extension can be injected.
75//! fn call(&self, req: &mut http::Request<Incoming>, data: &Option<PeerCertificate>) {
76//! if let Some(peer_certificate) = data {
77//! req.extensions_mut().insert(peer_certificate.clone());
78//! }
79//! }
80//! }
81//!
82//! # async fn serve() {
83//! let server = tower_server::Builder::new("0.0.0.0:443".parse().unwrap())
84//! .with_scheme(tower_server::Scheme::Https)
85//! .with_tls_connection_middleware(PeerCertMiddleware)
86//! .with_tls_config(
87//! rustls::server::ServerConfig::builder()
88//! // Instead of this, actually configure client authentication here:
89//! .with_no_client_auth()
90//! // just a compiling example for setting a cert resolver, replace this with your actual config:
91//! .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new()))
92//! )
93//! .bind()
94//! .await
95//! .unwrap();
96//!
97//! server.serve(axum::Router::new()).await;
98//! # }
99//! ```
100//!
//! ## Example using dynamically changing TLS configuration
102//! [tls::TlsConfigurer] is implemented for [futures_util::stream::BoxStream] of [Arc]ed [rustls::server::ServerConfig]s:
103//!
104//! ```rust
105//! # use std::sync::Arc;
106//! # use std::time::Duration;
107//! use futures_util::StreamExt;
108//!
109//! # async fn serve() {
110//! let initial_tls_config = Arc::new(
111//! rustls::server::ServerConfig::builder()
112//! .with_no_client_auth()
113//! .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new()))
114//! );
115//!
116//! let tls_config_rotation = futures_util::stream::unfold((), |_| async move {
117//! // renews after a fixed delay:
118//! tokio::time::sleep(Duration::from_secs(10)).await;
119//!
120//! // just for illustration purposes, replace with your own ServerConfig:
121//! let renewed_config = Arc::new(
122//! rustls::server::ServerConfig::builder()
123//! .with_no_client_auth()
124//! .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new()))
125//! );
126//!
127//! Some((renewed_config, ()))
128//! });
129//!
130//! let server = tower_server::Builder::new("0.0.0.0:443".parse().unwrap())
131//! .with_scheme(tower_server::Scheme::Https)
132//! .with_tls_config(
133//! // takes the initial config, which resolves without delay,
134//! // chained together with the subsequent dynamic updates:
135//! futures_util::stream::iter([initial_tls_config])
136//! .chain(tls_config_rotation)
137//! .boxed()
138//! )
139//! .bind()
140//! .await
141//! .unwrap();
142//!
143//! server.serve(axum::Router::new()).await;
144//! # }
145//! ```
146
147#![forbid(unsafe_code)]
148#![warn(missing_docs)]
149#![cfg_attr(feature = "unstable", feature(doc_auto_cfg))]
150
151use std::net::SocketAddr;
152use std::{error::Error as StdError, sync::Arc};
153
154use arc_swap::ArcSwap;
155use futures_util::future::poll_fn;
156use futures_util::stream::BoxStream;
157use futures_util::StreamExt;
158use hyper::body::Incoming;
159use hyper_util::rt::{TokioExecutor, TokioIo};
160use pin_utils::pin_mut;
161use rustls::ServerConfig;
162use tls::{NoOpTlsConnectionMiddleware, TlsConfigurer, TlsConnectionMiddleware};
163use tokio::net::TcpListener;
164use tokio::sync::watch;
165use tokio_rustls::TlsAcceptor;
166use tokio_util::sync::CancellationToken;
167use tracing::{info, trace};
168
169pub mod tls;
170
171#[cfg(feature = "signal")]
172pub mod signal;
173
/// Server configuration.
///
/// A builder collects the listen address, the HTTP scheme, the shutdown token and the
/// optional middlewares, and is turned into a bound [`TowerServer`] by calling `bind`.
///
/// The `TlsM` type parameter is the type of the registered TLS connection middleware.
pub struct Builder<TlsM> {
    /// Socket address that the TCP listener gets bound to.
    addr: SocketAddr,
    /// Whether to serve plain HTTP or HTTP over TLS.
    scheme: Scheme,
    /// Cancellation token handed over to the bound server; it also stops the
    /// background task that applies TLS configuration updates.
    cancel: CancellationToken,
    /// Function invoked with each incoming request and the remote address of its connection.
    connection_middleware: fn(&mut http::Request<Incoming>, SocketAddr),
    /// Middleware handed over to the bound server for use on TLS connections.
    tls_connection_middleware: TlsM,
    /// Whether `tls_config_stream` is expected to yield further configurations
    /// after the initial one (as reported by the registered TLS configurer).
    tls_config_is_dynamic: bool,
    /// Stream of TLS configurations; the first item is used as the initial configuration.
    tls_config_stream: BoxStream<'static, Arc<rustls::server::ServerConfig>>,
}
184
185impl Builder<NoOpTlsConnectionMiddleware> {
186 /// Configure using a socket addr using the Http scheme.
187 pub fn new(addr: SocketAddr) -> Self {
188 Self {
189 addr,
190 scheme: Scheme::Http,
191 cancel: Default::default(),
192 connection_middleware: |_, _| {},
193 tls_connection_middleware: NoOpTlsConnectionMiddleware,
194 tls_config_is_dynamic: false,
195 tls_config_stream: futures_util::stream::empty().boxed(),
196 }
197 }
198
199 /// Configure the tower server from a server Url with auto-configuration of http scheme.
200 #[cfg(feature = "url")]
201 pub fn from_url(base_url: url::Url) -> anyhow::Result<Self> {
202 use anyhow::anyhow;
203
204 let port = base_url
205 .port_or_known_default()
206 .ok_or_else(|| anyhow!("server port not deducible from base url"))?;
207 let addr: SocketAddr = match base_url.host() {
208 // treat domain name as binding on every interface
209 Some(url::Host::Domain(_)) => ([0, 0, 0, 0], port).into(),
210 Some(url::Host::Ipv4(v4)) => (v4, port).into(),
211 Some(url::Host::Ipv6(v6)) => (v6, port).into(),
212 None => return Err(anyhow!("no host in url")),
213 };
214
215 Ok(Self {
216 addr,
217 cancel: Default::default(),
218 connection_middleware: |_, _| {},
219 tls_connection_middleware: NoOpTlsConnectionMiddleware,
220 scheme: match base_url.scheme() {
221 "http" => Scheme::Http,
222 "https" => Scheme::Https,
223 scheme => return Err(anyhow!("unknown http server scheme: {scheme}")),
224 },
225 tls_config_is_dynamic: false,
226 tls_config_stream: futures_util::stream::empty().boxed(),
227 })
228 }
229}
230
231impl<TlsM> Builder<TlsM> {
232 /// Set the scheme used by the the server. A Https scheme requires a TLS config factory.
233 pub fn with_scheme(mut self, scheme: Scheme) -> Self {
234 self.scheme = scheme;
235 self
236 }
237
238 /// Register a function that acts a connection middleware on any accepted connection.
239 /// The middleware is able to modify every incoming request.
240 pub fn with_connection_middleware(mut self, middleware: ConnectionMiddleware) -> Self {
241 self.connection_middleware = middleware;
242 self
243 }
244
245 /// Register a TLS configurator.
246 /// TLS configuration will only be invoked when Scheme is set to Https.
247 pub fn with_tls_config(mut self, tls: impl TlsConfigurer) -> Self {
248 self.tls_config_is_dynamic = tls.is_dynamic();
249 self.tls_config_stream = tls.into_stream();
250 self
251 }
252
253 /// Register a TLS connection middleware.
254 pub fn with_tls_connection_middleware<T: TlsConnectionMiddleware>(
255 self,
256 middleware: T,
257 ) -> Builder<T> {
258 Builder {
259 addr: self.addr,
260 connection_middleware: self.connection_middleware,
261 tls_connection_middleware: middleware,
262 scheme: self.scheme,
263 cancel: self.cancel,
264 tls_config_is_dynamic: self.tls_config_is_dynamic,
265 tls_config_stream: self.tls_config_stream,
266 }
267 }
268
269 /// Register a cancellation token that enables graceful shutdown.
270 pub fn with_graceful_shutdown(mut self, cancel: CancellationToken) -> Self {
271 self.cancel = cancel;
272 self
273 }
274
275 /// Build server and bind it to the configured address.
276 pub async fn bind(self) -> anyhow::Result<TowerServer<TlsM>> {
277 let mut tls_config_stream = self.tls_config_stream;
278
279 let tls_config_swap = match self.scheme {
280 Scheme::Http => None,
281 Scheme::Https => {
282 let initial_tls_config = tls_config_stream.next().await.unwrap_or_else(|| {
283 panic!("Https scheme detected, but no TLS config registered")
284 });
285
286 let swap = Arc::new(ArcSwap::new(initial_tls_config));
287
288 // set up subscription for dynamically changing TLS config
289 if self.tls_config_is_dynamic {
290 let cancel = self.cancel.clone();
291 let swap = swap.clone();
292
293 tokio::spawn(async move {
294 loop {
295 tokio::select! {
296 next_tls_config = tls_config_stream.next() => {
297 if let Some(tls_config) = next_tls_config {
298 tracing::info!("renewing TLS ServerConfig");
299 swap.store(tls_config);
300 } else {
301 return;
302 }
303 }
304 _ = cancel.cancelled() => {
305 return;
306 }
307 }
308 }
309 });
310 }
311
312 Some(swap)
313 }
314 };
315
316 let listener = TcpListener::bind(self.addr).await?;
317
318 Ok(TowerServer {
319 listener,
320 tls_config_swap,
321 cancel: self.cancel,
322 connection_middleware: self.connection_middleware,
323 tls_connection_middleware: self.tls_connection_middleware,
324 })
325 }
326}
327
/// Desired HTTP scheme.
///
/// The scheme decides whether accepted connections are served as plain HTTP
/// or wrapped in TLS first.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Scheme {
    /// HTTP without TLS.
    Http,
    /// HTTP with TLS.
    Https,
}
336
/// The type of the connection middleware.
///
/// It is a function which receives a mutable request and a [SocketAddr] representing the remote client.
///
/// Since this is a plain function pointer, a closure can only be used as connection
/// middleware if it does not capture any variables from its environment.
pub type ConnectionMiddleware = fn(&mut http::Request<Incoming>, SocketAddr);
341
/// A bound server, ready for running accept-loop using a tower service.
///
/// The `TlsM` type parameter is the type of the TLS connection middleware.
pub struct TowerServer<TlsM = NoOpTlsConnectionMiddleware> {
    /// The bound TCP listener that connections are accepted from.
    listener: TcpListener,
    /// The current TLS configuration, or `None` when serving plain HTTP.
    /// The configuration is loaded from the swap for every accepted connection.
    tls_config_swap: Option<Arc<ArcSwap<ServerConfig>>>,
    /// Token which, when cancelled, stops the accept loop and
    /// starts graceful shutdown of the open connections.
    cancel: CancellationToken,
    /// Function invoked with each incoming request and the remote address of its connection.
    connection_middleware: fn(&mut http::Request<Incoming>, SocketAddr),
    /// Middleware used on TLS connections: it extracts data after the handshake
    /// and is invoked with that data for each request on the connection.
    tls_connection_middleware: TlsM,
}
350
351impl<TlsM> TowerServer<TlsM> {
352 /// Access the locally bound address
353 pub fn local_addr(&self) -> anyhow::Result<SocketAddr> {
354 self.listener.local_addr().map_err(|e| e.into())
355 }
356
357 /// Run HTTP accept loop, handling every request using the passwed tower service.
358 pub async fn serve<S, B>(self, tower_service: S)
359 where
360 S: tower_service::Service<
361 http::Request<hyper::body::Incoming>,
362 Response = http::Response<B>,
363 >
364 + Send
365 + Sync
366 + 'static
367 + Clone,
368 S::Future: 'static + Send,
369 S::Error: Into<Box<dyn StdError + Send + Sync + 'static>>,
370 B: http_body::Body + Send + 'static,
371 B::Data: Send,
372 B::Error: Into<Box<dyn StdError + Send + Sync + 'static>>,
373 TlsM: TlsConnectionMiddleware,
374 {
375 // tracks how long to gracefully await shutdown.
376 // Nothing is ever sent on this channel, it's only used for
377 // tracking the number of live receivers.
378 // each active connection has a clone of `close_rx`,
379 // at the end of the function `close_tx.closed()` is awaited,
380 // which finishes when no receivers are available.
381 let (close_tx, close_rx) = watch::channel(());
382
383 // accept loop
384 loop {
385 let (tcp_stream, remote_addr) = tokio::select! {
386 accept = self.listener.accept() => {
387 match accept {
388 Ok(stream_addr) => stream_addr,
389 Err(_) => {
390 continue;
391 }
392 }
393 }
394 _ = self.cancel.cancelled() => {
395 trace!("signal received, not accepting new connections");
396 break;
397 }
398 };
399
400 let tls_config_swap = self.tls_config_swap.clone();
401 let close_rx = close_rx.clone();
402 let cancel = self.cancel.clone();
403 let connection_middleware = self.connection_middleware;
404 let tls_connection_middleware = self.tls_connection_middleware.clone();
405 let tower_service = tower_service.clone();
406
407 tokio::spawn(async move {
408 let connection_builder =
409 hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
410 match tls_config_swap {
411 None => {
412 let connection = connection_builder.serve_connection_with_upgrades(
413 TokioIo::new(tcp_stream),
414 hyper::service::service_fn(move |mut req| {
415 connection_middleware(&mut req, remote_addr);
416 let mut tower_service = tower_service.clone();
417
418 async move {
419 poll_fn(|cx| tower_service.poll_ready(cx)).await?;
420 tower_service.call(req).await
421 }
422 }),
423 );
424 pin_mut!(connection);
425 tokio::select! {
426 biased;
427 _ = connection.as_mut() => {}
428 _ = cancel.cancelled() => {
429 connection.as_mut().graceful_shutdown();
430 let _ = connection.as_mut().await;
431 }
432 }
433 }
434 Some(tls_config_swap) => {
435 let tls_acceptor = TlsAcceptor::from(tls_config_swap.load_full());
436 let tls_stream = match tls_acceptor.accept(tcp_stream).await {
437 Ok(tls_stream) => tls_stream,
438 Err(err) => {
439 info!(?err, "failed to perform tls handshake");
440 return;
441 }
442 };
443
444 let tls_middleware_data =
445 tls_connection_middleware.data(tls_stream.get_ref().1);
446
447 let connection = connection_builder.serve_connection_with_upgrades(
448 TokioIo::new(tls_stream),
449 hyper::service::service_fn(move |mut req| {
450 connection_middleware(&mut req, remote_addr);
451 tls_connection_middleware.call(&mut req, &tls_middleware_data);
452 let mut tower_service = tower_service.clone();
453
454 async move {
455 poll_fn(|cx| tower_service.poll_ready(cx)).await?;
456 tower_service.call(req).await
457 }
458 }),
459 );
460
461 pin_mut!(connection);
462 tokio::select! {
463 biased;
464 _ = connection.as_mut() => {}
465 _ = cancel.cancelled() => {
466 connection.as_mut().graceful_shutdown();
467 let _ = connection.as_mut().await;
468 }
469 }
470 }
471 }
472
473 drop(close_rx);
474 });
475 }
476
477 drop(close_rx);
478 trace!(
479 "waiting for {} task(s) to finish",
480 close_tx.receiver_count()
481 );
482 close_tx.closed().await;
483 }
484}