hyper_server/tls_rustls/
mod.rs

//! TLS implementation using [`rustls`].
//!
//! # Example
//!
//! ```rust,no_run
//! use axum::{routing::get, Router};
//! use hyper_server::tls_rustls::RustlsConfig;
//! use std::net::SocketAddr;
//!
//! #[tokio::main]
//! async fn main() {
//!     let app = Router::new().route("/", get(|| async { "Hello, world!" }));
//!
//!     let config = RustlsConfig::from_pem_file(
//!         "examples/self-signed-certs/cert.pem",
//!         "examples/self-signed-certs/key.pem",
//!     )
//!     .await
//!     .unwrap();
//!
//!     let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
//!     println!("listening on {}", addr);
//!     hyper_server::bind_rustls(addr, config)
//!         .serve(app.into_make_service())
//!         .await
//!         .unwrap();
//! }
//! ```
29
30use self::future::RustlsAcceptorFuture;
31use crate::{
32    accept::{Accept, DefaultAcceptor},
33    server::{io_other, Server},
34};
35use arc_swap::ArcSwap;
36use rustls::{Certificate, PrivateKey, ServerConfig};
37use std::time::Duration;
38use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc};
39use tokio::{
40    io::{AsyncRead, AsyncWrite},
41    task::spawn_blocking,
42};
43use tokio_rustls::server::TlsStream;
44
45/// Sub-module that contains re-exported public interfaces.
46pub(crate) mod export {
47    use super::{RustlsAcceptor, RustlsConfig, Server, SocketAddr};
48
49    /// Creates a TLS server that binds to the provided address using the rustls library.
50    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
51    pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
52        super::bind_rustls(addr, config)
53    }
54
55    /// Creates a TLS server from an existing `std::net::TcpListener` using the rustls library.
56    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
57    pub fn from_tcp_rustls(
58        listener: std::net::TcpListener,
59        config: RustlsConfig,
60    ) -> Server<RustlsAcceptor> {
61        let acceptor = RustlsAcceptor::new(config);
62
63        Server::from_tcp(listener).acceptor(acceptor)
64    }
65}
66
67pub mod future;
68
69/// Helper function to create a TLS server bound to a provided address.
70pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
71    let acceptor = RustlsAcceptor::new(config);
72
73    Server::bind(addr).acceptor(acceptor)
74}
75
76/// Helper function to create a TLS server from an existing `std::net::TcpListener`.
77pub fn from_tcp_rustls(
78    listener: std::net::TcpListener,
79    config: RustlsConfig,
80) -> Server<RustlsAcceptor> {
81    let acceptor = RustlsAcceptor::new(config);
82
83    Server::from_tcp(listener).acceptor(acceptor)
84}
85
86/// A TLS acceptor implementation using the rustls library.
87#[derive(Clone)]
88pub struct RustlsAcceptor<A = DefaultAcceptor> {
89    inner: A,
90    config: RustlsConfig,
91    handshake_timeout: Duration,
92}
93
94impl RustlsAcceptor {
95    /// Constructs a new rustls acceptor with the given configuration.
96    pub fn new(config: RustlsConfig) -> Self {
97        let inner = DefaultAcceptor::new();
98
99        // Default handshake timeout is set to 10 seconds.
100        // In test mode, this is reduced to 1 second to avoid waiting too long.
101        #[cfg(not(test))]
102        let handshake_timeout = Duration::from_secs(10);
103        #[cfg(test)]
104        let handshake_timeout = Duration::from_secs(1);
105
106        Self {
107            inner,
108            config,
109            handshake_timeout,
110        }
111    }
112
113    /// Allows overriding the default TLS handshake timeout.
114    pub fn handshake_timeout(mut self, val: Duration) -> Self {
115        self.handshake_timeout = val;
116        self
117    }
118}
119
120impl<A> RustlsAcceptor<A> {
121    /// Replaces the inner acceptor with a custom acceptor.
122    pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> RustlsAcceptor<Acceptor> {
123        RustlsAcceptor {
124            inner: acceptor,
125            config: self.config,
126            handshake_timeout: self.handshake_timeout,
127        }
128    }
129}
130
131// Implementation to accept incoming TLS connections using rustls.
132impl<A, I, S> Accept<I, S> for RustlsAcceptor<A>
133where
134    A: Accept<I, S>,
135    A::Stream: AsyncRead + AsyncWrite + Unpin,
136{
137    type Stream = TlsStream<A::Stream>;
138    type Service = A::Service;
139    type Future = RustlsAcceptorFuture<A::Future, A::Stream, A::Service>;
140
141    fn accept(&self, stream: I, service: S) -> Self::Future {
142        let inner_future = self.inner.accept(stream, service);
143        let config = self.config.clone();
144
145        RustlsAcceptorFuture::new(inner_future, config, self.handshake_timeout)
146    }
147}
148
149impl<A> fmt::Debug for RustlsAcceptor<A> {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        f.debug_struct("RustlsAcceptor").finish()
152    }
153}
154
155/// Represents the rustls configuration for the server.
156#[derive(Clone)]
157pub struct RustlsConfig {
158    inner: Arc<ArcSwap<ServerConfig>>,
159}
160
161// The `RustlsConfig` structure represents configuration data for rustls.
162impl RustlsConfig {
163    /// Create a new `RustlsConfig` from an `Arc<ServerConfig>`.
164    ///
165    /// Important: This method does not set ALPN protocols (like `http/1.1` or `h2`) automatically.
166    /// ALPN protocols need to be set manually when using this method.
167    pub fn from_config(config: Arc<ServerConfig>) -> Self {
168        let inner = Arc::new(ArcSwap::new(config));
169        Self { inner }
170    }
171
172    /// Create a `RustlsConfig` from DER-encoded data.
173    /// DER is a binary format for encoding data, commonly used for certificates and keys.
174    ///
175    /// `cert` is expected to be a DER-encoded X.509 certificate.
176    /// `key` is expected to be a DER-encoded ASN.1 format private key, either in PKCS#8 or PKCS#1 format.
177    pub async fn from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<Self> {
178        let server_config = spawn_blocking(|| config_from_der(cert, key))
179            .await
180            .unwrap()?;
181        let inner = Arc::new(ArcSwap::from_pointee(server_config));
182        Ok(Self { inner })
183    }
184
185    /// Create a `RustlsConfig` from PEM-formatted data.
186    /// PEM is a text-based format used to encode binary data like certificates and keys.
187    ///
188    /// Both `cert` and `key` must be provided in PEM format.
189    pub async fn from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<Self> {
190        let server_config = spawn_blocking(|| config_from_pem(cert, key))
191            .await
192            .unwrap()?;
193        let inner = Arc::new(ArcSwap::from_pointee(server_config));
194        Ok(Self { inner })
195    }
196
197    /// Create a `RustlsConfig` by reading PEM-formatted files.
198    ///
199    /// The contents of the provided certificate and private key files must be in PEM format.
200    pub async fn from_pem_file(cert: impl AsRef<Path>, key: impl AsRef<Path>) -> io::Result<Self> {
201        let server_config = config_from_pem_file(cert, key).await?;
202        let inner = Arc::new(ArcSwap::from_pointee(server_config));
203        Ok(Self { inner })
204    }
205
206    /// Retrieve the inner `Arc<ServerConfig>` from the `RustlsConfig`.
207    pub fn get_inner(&self) -> Arc<ServerConfig> {
208        self.inner.load_full()
209    }
210
211    /// Update (or reload) the `RustlsConfig` with a new `Arc<ServerConfig>`.
212    pub fn reload_from_config(&self, config: Arc<ServerConfig>) {
213        self.inner.store(config);
214    }
215
216    /// Reload the `RustlsConfig` from provided DER-encoded data.
217    ///
218    /// As with the `from_der` method, `cert` must be DER-encoded X.509 and `key`
219    /// should be in either PKCS#8 or PKCS#1 DER-encoded ASN.1 format.
220    pub async fn reload_from_der(&self, cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<()> {
221        let server_config = spawn_blocking(|| config_from_der(cert, key))
222            .await
223            .unwrap()?;
224        let inner = Arc::new(server_config);
225        self.inner.store(inner);
226        Ok(())
227    }
228
229    /// Reload the `RustlsConfig` using provided PEM-formatted data.
230    pub async fn reload_from_pem(&self, cert: Vec<u8>, key: Vec<u8>) -> io::Result<()> {
231        let server_config = spawn_blocking(|| config_from_pem(cert, key))
232            .await
233            .unwrap()?;
234        let inner = Arc::new(server_config);
235        self.inner.store(inner);
236        Ok(())
237    }
238
239    /// Reload the `RustlsConfig` from provided PEM-formatted files.
240    pub async fn reload_from_pem_file(
241        &self,
242        cert: impl AsRef<Path>,
243        key: impl AsRef<Path>,
244    ) -> io::Result<()> {
245        let server_config = config_from_pem_file(cert, key).await?;
246        let inner = Arc::new(server_config);
247        self.inner.store(inner);
248        Ok(())
249    }
250}
251
252// This provides a debug representation for the `RustlsConfig`.
253impl fmt::Debug for RustlsConfig {
254    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255        f.debug_struct("RustlsConfig").finish()
256    }
257}
258
259// Helper function to convert DER-encoded certificate and key into rustls's `ServerConfig`.
260fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig> {
261    // Convert the raw bytes into rustls's Certificate and PrivateKey structures.
262    let cert = cert.into_iter().map(Certificate).collect();
263    let key = PrivateKey(key);
264
265    // Construct the ServerConfig.
266    let mut config = ServerConfig::builder()
267        .with_safe_defaults()
268        .with_no_client_auth()
269        .with_single_cert(cert, key)
270        .map_err(io_other)?;
271
272    // Set ALPN protocols for the configuration.
273    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
274
275    Ok(config)
276}
277
278// Helper function to convert PEM-formatted certificate and key into rustls' `ServerConfig`.
279fn config_from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<ServerConfig> {
280    use rustls_pemfile::Item;
281
282    // Parse PEM formatted data into rustls structures.
283    let cert = rustls_pemfile::certs(&mut cert.as_ref())?;
284    let key = match rustls_pemfile::read_one(&mut key.as_ref())? {
285        Some(Item::RSAKey(key)) | Some(Item::PKCS8Key(key)) | Some(Item::ECKey(key)) => key,
286        _ => return Err(io_other("private key format not supported")),
287    };
288
289    config_from_der(cert, key)
290}
291
292// Helper function to read PEM-formatted files and convert them into rustls' ServerConfig.
293async fn config_from_pem_file(
294    cert: impl AsRef<Path>,
295    key: impl AsRef<Path>,
296) -> io::Result<ServerConfig> {
297    // Read the PEM files asynchronously.
298    let cert = tokio::fs::read(cert.as_ref()).await?;
299    let key = tokio::fs::read(key.as_ref()).await?;
300
301    config_from_pem(cert, key)
302}
303
#[cfg(test)]
pub(crate) mod tests {
    //! End-to-end tests: each test spawns a real rustls-backed server on an
    //! ephemeral localhost port and talks to it over TCP/TLS.

    use crate::{
        handle::Handle,
        tls_rustls::{self, RustlsConfig},
    };
    use axum::{routing::get, Router};
    use bytes::Bytes;
    use http::{response, Request};
    use hyper::{
        client::conn::{handshake, SendRequest},
        Body,
    };
    use rustls::{
        client::{ServerCertVerified, ServerCertVerifier},
        Certificate, ClientConfig, ServerName,
    };
    use std::{
        convert::TryFrom,
        io,
        net::SocketAddr,
        sync::Arc,
        time::{Duration, SystemTime},
    };
    use tokio::time::sleep;
    use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
    use tokio_rustls::TlsConnector;
    use tower::{Service, ServiceExt};

    /// Full round trip: TLS handshake, one request, and the expected response body.
    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    /// A TCP connection that never starts a TLS handshake is counted while
    /// pending and is dropped once the handshake timeout has elapsed.
    #[tokio::test]
    async fn tls_timeout() {
        let (handle, _server_task, addr) = start_server().await;
        assert_eq!(handle.connection_count(), 0);

        // We intentionally avoid driving a TLS handshake to completion.
        let _stream = TcpStream::connect(addr).await.unwrap();

        // Halfway through the timeout the connection is still tracked.
        sleep(Duration::from_millis(500)).await;
        assert_eq!(handle.connection_count(), 1);

        tokio::time::sleep(Duration::from_millis(1000)).await;
        // Timeout defaults to 1s during testing, and we have waited 1.5 seconds.
        assert_eq!(handle.connection_count(), 0);
    }

    /// Reloading the config from different PEM files changes the certificate
    /// presented to new connections; reloading the original files restores it.
    #[tokio::test]
    async fn test_reload() {
        let handle = Handle::new();

        let config = RustlsConfig::from_pem_file(
            "examples/self-signed-certs/cert.pem",
            "examples/self-signed-certs/key.pem",
        )
        .await
        .unwrap();

        // The server gets a clone; reloads through `config` below reach it because
        // clones share the same underlying slot.
        let server_handle = handle.clone();
        let rustls_config = config.clone();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_rustls::bind_rustls(addr, rustls_config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        // Without a reload, consecutive connections see the same certificate.
        let cert_a = get_first_cert(addr).await;
        let mut cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/reload/cert.pem",
                "examples/self-signed-certs/reload/key.pem",
            )
            .await
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_ne!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .await
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);
    }

    /// After an immediate `shutdown`, requests on an already-open connection fail
    /// and the client connection task ends shortly afterwards.
    #[tokio::test]
    async fn test_shutdown() {
        let (handle, _server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.shutdown();

        let response_future_result = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await;

        assert!(response_future_result.is_err());

        // Connection task should finish soon.
        let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
    }

    /// During a graceful shutdown with no deadline, an already-open connection is
    /// still served; the server task finishes once the client disconnects.
    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.graceful_shutdown(None);

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");

        // Disconnect client.
        conn.abort();

        // Server task should finish soon.
        let server_result = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();

        assert!(server_result.is_ok());
    }

    /// With a graceful-shutdown deadline, the server task finishes even though
    /// the client stays connected.
    #[tokio::test]
    async fn test_graceful_shutdown_timed() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        handle.graceful_shutdown(Some(Duration::from_millis(250)));

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");

        // The client is intentionally kept connected (no `conn.abort()` here), so
        // only the 250ms deadline can end the graceful shutdown.

        // Server task should finish soon.
        let server_result = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();

        assert!(server_result.is_ok());
    }

    /// Spawns a TLS server on an ephemeral localhost port using the example
    /// self-signed certificate, and waits until it is listening.
    ///
    /// Returns the server handle, the server task, and the bound address.
    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let config = RustlsConfig::from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .await?;

            // Port 0: let the OS pick a free port; the real address comes from `listening()`.
            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_rustls::bind_rustls(addr, config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Performs a TLS handshake against `addr` and returns the first certificate
    /// the server presented.
    async fn get_first_cert(addr: SocketAddr) -> Certificate {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap();

        let (_io, client_connection) = tls_stream.into_inner();

        client_connection.peer_certificates().unwrap()[0].clone()
    }

    /// Opens a TLS connection to `addr` and performs a hyper client handshake over it.
    ///
    /// Returns the request sender and the task driving the connection.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap();

        let (send_request, connection) = handshake(tls_stream).await.unwrap();

        // The connection future must be polled for requests to make progress.
        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    /// Sends a default request (`GET /`) with an empty body and collects the full
    /// response body.
    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let (parts, body) = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_parts();
        let body = hyper::body::to_bytes(body).await.unwrap();

        (parts, body)
    }

    /// Builds a client connector that accepts any server certificate (the tests use
    /// self-signed certificates) and offers `h2` and `http/1.1` via ALPN.
    ///
    /// Used in `proxy-protocol` feature tests.
    pub(crate) fn tls_connector() -> TlsConnector {
        // Test-only verifier that skips all server certificate checks.
        struct NoVerify;

        impl ServerCertVerifier for NoVerify {
            fn verify_server_cert(
                &self,
                _end_entity: &Certificate,
                _intermediates: &[Certificate],
                _server_name: &ServerName,
                _scts: &mut dyn Iterator<Item = &[u8]>,
                _ocsp_response: &[u8],
                _now: SystemTime,
            ) -> Result<ServerCertVerified, rustls::Error> {
                Ok(ServerCertVerified::assertion())
            }
        }

        let mut client_config = ClientConfig::builder()
            .with_safe_defaults()
            .with_custom_certificate_verifier(Arc::new(NoVerify))
            .with_no_client_auth();

        client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

        TlsConnector::from(Arc::new(client_config))
    }

    /// Server name (`localhost`) that test clients pass to the TLS connector.
    ///
    /// Used in `proxy-protocol` feature tests.
    pub(crate) fn dns_name() -> ServerName {
        ServerName::try_from("localhost").unwrap()
    }
}