hyper_serve/tls_rustls/
mod.rs

//! TLS implementation using [`rustls`].
2//!
3//! # Example
4//!
5//! ```rust,no_run
6//! use axum::{routing::get, Router};
7//! use hyper_serve::tls_rustls::RustlsConfig;
8//! use std::net::SocketAddr;
9//!
10//! #[tokio::main]
11//! async fn main() {
12//!     let app = Router::new().route("/", get(|| async { "Hello, world!" }));
13//!
14//!     let config = RustlsConfig::from_pem_file(
15//!         "examples/self-signed-certs/cert.pem",
16//!         "examples/self-signed-certs/key.pem",
17//!     )
18//!     .await
19//!     .unwrap();
20//!
21//!     let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
22//!     println!("listening on {}", addr);
23//!     hyper_serve::bind_rustls(addr, config)
24//!         .serve(app.into_make_service())
25//!         .await
26//!         .unwrap();
27//! }
28//! ```
29
30use self::future::RustlsAcceptorFuture;
31use crate::{
32    accept::{Accept, DefaultAcceptor},
33    server::{io_other, Server},
34};
35use arc_swap::ArcSwap;
36use rustls::ServerConfig;
37use rustls_pemfile::Item;
38use rustls_pki_types::{CertificateDer, PrivateKeyDer};
39use std::time::Duration;
40use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc};
41use tokio::{
42    io::{AsyncRead, AsyncWrite},
43    task::spawn_blocking,
44};
45use tokio_rustls::server::TlsStream;
46
47pub(crate) mod export {
48    use super::*;
49
50    /// Create a tls server that will bind to provided address.
51    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
52    pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
53        super::bind_rustls(addr, config)
54    }
55
56    /// Create a tls server from existing `std::net::TcpListener`.
57    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
58    pub fn from_tcp_rustls(
59        listener: std::net::TcpListener,
60        config: RustlsConfig,
61    ) -> Server<RustlsAcceptor> {
62        let acceptor = RustlsAcceptor::new(config);
63
64        Server::from_tcp(listener).acceptor(acceptor)
65    }
66}
67
68pub mod future;
69
70/// Create a tls server that will bind to provided address.
71pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
72    let acceptor = RustlsAcceptor::new(config);
73
74    Server::bind(addr).acceptor(acceptor)
75}
76
77/// Create a tls server from existing `std::net::TcpListener`.
78pub fn from_tcp_rustls(
79    listener: std::net::TcpListener,
80    config: RustlsConfig,
81) -> Server<RustlsAcceptor> {
82    let acceptor = RustlsAcceptor::new(config);
83
84    Server::from_tcp(listener).acceptor(acceptor)
85}
86
87/// Tls acceptor using rustls.
88#[derive(Clone)]
89pub struct RustlsAcceptor<A = DefaultAcceptor> {
90    inner: A,
91    config: RustlsConfig,
92    handshake_timeout: Duration,
93}
94
95impl RustlsAcceptor {
96    /// Create a new rustls acceptor.
97    pub fn new(config: RustlsConfig) -> Self {
98        let inner = DefaultAcceptor::new();
99
100        #[cfg(not(test))]
101        let handshake_timeout = Duration::from_secs(10);
102
103        // Don't force tests to wait too long.
104        #[cfg(test)]
105        let handshake_timeout = Duration::from_secs(1);
106
107        Self {
108            inner,
109            config,
110            handshake_timeout,
111        }
112    }
113
114    /// Override the default TLS handshake timeout of 10 seconds, except during testing.
115    pub fn handshake_timeout(mut self, val: Duration) -> Self {
116        self.handshake_timeout = val;
117        self
118    }
119}
120
121impl<A> RustlsAcceptor<A> {
122    /// Overwrite inner acceptor.
123    pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> RustlsAcceptor<Acceptor> {
124        RustlsAcceptor {
125            inner: acceptor,
126            config: self.config,
127            handshake_timeout: self.handshake_timeout,
128        }
129    }
130}
131
132impl<A, I, S> Accept<I, S> for RustlsAcceptor<A>
133where
134    A: Accept<I, S>,
135    A::Stream: AsyncRead + AsyncWrite + Unpin,
136{
137    type Stream = TlsStream<A::Stream>;
138    type Service = A::Service;
139    type Future = RustlsAcceptorFuture<A::Future, A::Stream, A::Service>;
140
141    fn accept(&self, stream: I, service: S) -> Self::Future {
142        let inner_future = self.inner.accept(stream, service);
143        let config = self.config.clone();
144
145        RustlsAcceptorFuture::new(inner_future, config, self.handshake_timeout)
146    }
147}
148
149impl<A> fmt::Debug for RustlsAcceptor<A> {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        f.debug_struct("RustlsAcceptor").finish()
152    }
153}
154
155/// Rustls configuration.
156#[derive(Clone)]
157pub struct RustlsConfig {
158    inner: Arc<ArcSwap<ServerConfig>>,
159}
160
161impl RustlsConfig {
162    /// Create config from `Arc<`[`ServerConfig`]`>`.
163    ///
164    /// NOTE: You need to set ALPN protocols (like `http/1.1` or `h2`) manually.
165    pub fn from_config(config: Arc<ServerConfig>) -> Self {
166        let inner = Arc::new(ArcSwap::new(config));
167
168        Self { inner }
169    }
170
171    /// Create config from DER-encoded data.
172    ///
173    /// The certificate must be DER-encoded X.509.
174    ///
175    /// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
176    pub async fn from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<Self> {
177        let server_config = spawn_blocking(|| config_from_der(cert, key))
178            .await
179            .unwrap()?;
180        let inner = Arc::new(ArcSwap::from_pointee(server_config));
181
182        Ok(Self { inner })
183    }
184
185    /// Create config from PEM formatted data.
186    ///
187    /// Certificate and private key must be in PEM format.
188    pub async fn from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<Self> {
189        let server_config = spawn_blocking(|| config_from_pem(cert, key))
190            .await
191            .unwrap()?;
192        let inner = Arc::new(ArcSwap::from_pointee(server_config));
193
194        Ok(Self { inner })
195    }
196
197    /// Create config from PEM formatted files.
198    ///
199    /// Contents of certificate file and private key file must be in PEM format.
200    pub async fn from_pem_file(cert: impl AsRef<Path>, key: impl AsRef<Path>) -> io::Result<Self> {
201        let server_config = config_from_pem_file(cert, key).await?;
202        let inner = Arc::new(ArcSwap::from_pointee(server_config));
203
204        Ok(Self { inner })
205    }
206
207    /// Get  inner `Arc<`[`ServerConfig`]`>`.
208    pub fn get_inner(&self) -> Arc<ServerConfig> {
209        self.inner.load_full()
210    }
211
212    /// Reload config from `Arc<`[`ServerConfig`]`>`.
213    pub fn reload_from_config(&self, config: Arc<ServerConfig>) {
214        self.inner.store(config);
215    }
216
217    /// Reload config from DER-encoded data.
218    ///
219    /// The certificate must be DER-encoded X.509.
220    ///
221    /// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
222    pub async fn reload_from_der(&self, cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<()> {
223        let server_config = spawn_blocking(|| config_from_der(cert, key))
224            .await
225            .unwrap()?;
226        let inner = Arc::new(server_config);
227
228        self.inner.store(inner);
229
230        Ok(())
231    }
232
233    /// This helper will establish a TLS server based on strong cipher suites
234    /// from a PEM-formatted certificate chain and key.
235    pub async fn from_pem_chain_file(
236        chain: impl AsRef<Path>,
237        key: impl AsRef<Path>,
238    ) -> io::Result<Self> {
239        let server_config = config_from_pem_chain_file(chain, key).await?;
240        let inner = Arc::new(ArcSwap::from_pointee(server_config));
241
242        Ok(Self { inner })
243    }
244
245    /// Reload config from PEM formatted data.
246    ///
247    /// Certificate and private key must be in PEM format.
248    pub async fn reload_from_pem(&self, cert: Vec<u8>, key: Vec<u8>) -> io::Result<()> {
249        let server_config = spawn_blocking(|| config_from_pem(cert, key))
250            .await
251            .unwrap()?;
252        let inner = Arc::new(server_config);
253
254        self.inner.store(inner);
255
256        Ok(())
257    }
258
259    /// Reload config from PEM formatted files.
260    ///
261    /// Contents of certificate file and private key file must be in PEM format.
262    pub async fn reload_from_pem_file(
263        &self,
264        cert: impl AsRef<Path>,
265        key: impl AsRef<Path>,
266    ) -> io::Result<()> {
267        let server_config = config_from_pem_file(cert, key).await?;
268        let inner = Arc::new(server_config);
269
270        self.inner.store(inner);
271
272        Ok(())
273    }
274}
275
276impl fmt::Debug for RustlsConfig {
277    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278        f.debug_struct("RustlsConfig").finish()
279    }
280}
281
282fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig> {
283    let cert = cert.into_iter().map(CertificateDer::from).collect();
284    let key = PrivateKeyDer::try_from(key).map_err(io_other)?;
285
286    let mut config = ServerConfig::builder()
287        .with_no_client_auth()
288        .with_single_cert(cert, key)
289        .map_err(io_other)?;
290
291    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
292
293    Ok(config)
294}
295
296fn config_from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<ServerConfig> {
297    let cert = rustls_pemfile::certs(&mut cert.as_ref())
298        .map(|it| it.map(|it| it.to_vec()))
299        .collect::<Result<Vec<_>, _>>()?;
300    // Check the entire PEM file for the key in case it is not first section
301    let mut key_vec: Vec<Vec<u8>> = rustls_pemfile::read_all(&mut key.as_ref())
302        .filter_map(|i| match i.ok()? {
303            Item::Sec1Key(key) => Some(key.secret_sec1_der().to_vec()),
304            Item::Pkcs1Key(key) => Some(key.secret_pkcs1_der().to_vec().into()),
305            Item::Pkcs8Key(key) => Some(key.secret_pkcs8_der().to_vec().into()),
306            _ => None,
307        })
308        .collect();
309
310    // Make sure file contains only one key
311    if key_vec.len() != 1 {
312        return Err(io_other("private key format not supported"));
313    }
314
315    config_from_der(cert, key_vec.pop().unwrap())
316}
317
318async fn config_from_pem_file(
319    cert: impl AsRef<Path>,
320    key: impl AsRef<Path>,
321) -> io::Result<ServerConfig> {
322    let cert = tokio::fs::read(cert.as_ref()).await?;
323    let key = tokio::fs::read(key.as_ref()).await?;
324
325    config_from_pem(cert, key)
326}
327
328async fn config_from_pem_chain_file(
329    cert: impl AsRef<Path>,
330    chain: impl AsRef<Path>,
331) -> io::Result<ServerConfig> {
332    let cert = tokio::fs::read(cert.as_ref()).await?;
333    let cert = rustls_pemfile::certs(&mut cert.as_ref())
334        .map(|it| it.map(|it| CertificateDer::from(it.to_vec())))
335        .collect::<Result<Vec<_>, _>>()?;
336    let key = tokio::fs::read(chain.as_ref()).await?;
337    let key_cert: PrivateKeyDer = match rustls_pemfile::read_one(&mut key.as_ref())?
338        .ok_or_else(|| io_other("could not parse pem file"))?
339    {
340        Item::Pkcs8Key(key) => {
341            Ok(PrivateKeyDer::try_from(key.secret_pkcs8_der().to_vec()).map_err(io_other)?)
342        }
343        x => Err(io_other(format!(
344            "invalid certificate format, received: {x:?}"
345        ))),
346    }?;
347
348    ServerConfig::builder()
349        .with_no_client_auth()
350        .with_single_cert(cert, key_cert)
351        .map_err(|_| io_other("invalid certificate"))
352}
353
354#[cfg(test)]
355mod tests {
356    use crate::handle::Handle;
357    use crate::server::tests::slow_body;
358    use crate::tls_rustls::{self, RustlsConfig};
359    use axum::body::Body;
360    use axum::response::Response;
361    use axum::routing::{get, post};
362    use axum::Router;
363    use bytes::Bytes;
364    use http::{response, Request};
365    use http_body_util::BodyExt;
366    use hyper::client::conn::http1::{handshake, SendRequest};
367    use hyper_util::rt::TokioIo;
368    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
369    use rustls::{ClientConfig, DigitallySignedStruct, Error, SignatureScheme};
370    use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
371    use std::fmt::Debug;
372    use std::{convert::TryFrom, io, net::SocketAddr, sync::Arc, time::Duration};
373    use tokio::sync::oneshot;
374    use tokio::time::sleep;
375    use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
376    use tokio_rustls::TlsConnector;
377
378    #[tokio::test]
379    async fn start_and_request() {
380        let (_handle, _server_task, addr) = start_server().await;
381
382        let (mut client, _conn) = connect(addr).await;
383
384        let (_parts, body) = send_empty_request(&mut client).await;
385
386        assert_eq!(body.as_ref(), b"Hello, world!");
387    }
388
389    #[ignore]
390    #[tokio::test]
391    async fn tls_timeout() {
392        let (handle, _server_task, addr) = start_server().await;
393        assert_eq!(handle.connection_count(), 0);
394
395        // We intentionally avoid driving a TLS handshake to completion.
396        let _stream = TcpStream::connect(addr).await.unwrap();
397
398        sleep(Duration::from_millis(500)).await;
399        assert_eq!(handle.connection_count(), 1);
400
401        tokio::time::sleep(Duration::from_millis(1000)).await;
402        // Timeout defaults to 1s during testing, and we have waited 1.5 seconds.
403        assert_eq!(handle.connection_count(), 0);
404    }
405
406    #[tokio::test]
407    async fn test_reload() {
408        let handle = Handle::new();
409
410        let config = RustlsConfig::from_pem_file(
411            "examples/self-signed-certs/cert.pem",
412            "examples/self-signed-certs/key.pem",
413        )
414        .await
415        .unwrap();
416
417        let server_handle = handle.clone();
418        let rustls_config = config.clone();
419        tokio::spawn(async move {
420            let app = Router::new().route("/", get(|| async { "Hello, world!" }));
421
422            let addr = SocketAddr::from(([127, 0, 0, 1], 0));
423
424            tls_rustls::bind_rustls(addr, rustls_config)
425                .handle(server_handle)
426                .serve(app.into_make_service())
427                .await
428        });
429
430        let addr = handle.listening().await.unwrap();
431
432        let cert_a = get_first_cert(addr).await;
433        let mut cert_b = get_first_cert(addr).await;
434
435        assert_eq!(cert_a, cert_b);
436
437        config
438            .reload_from_pem_file(
439                "examples/self-signed-certs/reload/cert.pem",
440                "examples/self-signed-certs/reload/key.pem",
441            )
442            .await
443            .unwrap();
444
445        cert_b = get_first_cert(addr).await;
446
447        assert_ne!(cert_a, cert_b);
448
449        config
450            .reload_from_pem_file(
451                "examples/self-signed-certs/cert.pem",
452                "examples/self-signed-certs/key.pem",
453            )
454            .await
455            .unwrap();
456
457        cert_b = get_first_cert(addr).await;
458
459        assert_eq!(cert_a, cert_b);
460    }
461
462    #[tokio::test]
463    async fn test_shutdown() {
464        let (handle, _server_task, addr) = start_server().await;
465
466        let (mut client, conn) = connect(addr).await;
467
468        handle.shutdown();
469
470        let response_future_result = client.send_request(Request::new(Body::empty())).await;
471
472        assert!(response_future_result.is_err());
473
474        // Connection task should finish soon.
475        let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
476    }
477
478    #[tokio::test]
479    async fn test_graceful_shutdown_timeout() {
480        let (handle, server_task, addr) = start_server().await;
481
482        let (mut client1, _conn1) = connect(addr).await;
483        let (mut client2, _conn2) = connect(addr).await;
484
485        // Clients can send request before graceful shutdown.
486        crate::server::tests::do_empty_request(&mut client1)
487            .await
488            .unwrap();
489        crate::server::tests::do_empty_request(&mut client2)
490            .await
491            .unwrap();
492
493        let start = tokio::time::Instant::now();
494
495        let (hdr1_tx, hdr1_rx) = oneshot::channel::<()>();
496
497        let task1 = async {
498            // A slow request made before graceful shutdown is handled.
499            // This one is shorter than the timeout, so it should succeed.
500            let hdr1 =
501                crate::server::tests::send_slow_request(&mut client1, Duration::from_millis(222))
502                    .await;
503            hdr1_tx.send(()).unwrap();
504
505            let res1 = crate::server::tests::recv_slow_response_body(hdr1.unwrap()).await;
506            res1.unwrap();
507        };
508        let task2 = async {
509            // A slow request made before graceful shutdown is handled.
510            // This one is much longer than the timeout; it should fail sometime
511            // after the graceful shutdown timeout.
512
513            let hdr2 =
514                crate::server::tests::send_slow_request(&mut client2, Duration::from_millis(5_555))
515                    .await;
516            hdr2.unwrap_err();
517        };
518        let task3 = async {
519            // Begin graceful shutdown after we receive response headers for (1).
520            hdr1_rx.await.unwrap();
521
522            // Set a timeout on requests to finish before we drop them.
523            handle.graceful_shutdown(Some(Duration::from_millis(333)));
524
525            // Server task should finish soon.
526            timeout(Duration::from_secs(1), server_task)
527                .await
528                .unwrap()
529                .unwrap()
530                .unwrap();
531
532            // At this point, graceful shutdown must have occured.
533            assert!(start.elapsed() >= Duration::from_millis(222 + 333));
534            assert!(start.elapsed() <= Duration::from_millis(5_555));
535        };
536
537        tokio::join!(task1, task2, task3);
538    }
539
540    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
541        let handle = Handle::new();
542
543        let server_handle = handle.clone();
544        let server_task = tokio::spawn(async move {
545            let app = Router::new()
546                .route("/", get(|| async { "Hello, world!" }))
547                .route(
548                    "/echo_slowly",
549                    post(|body: Bytes| async move {
550                        // Stream a response slowly, byte-by-byte, over 100ms
551                        Response::new(slow_body(body.len(), Duration::from_millis(100)))
552                    }),
553                );
554
555            let config = RustlsConfig::from_pem_file(
556                "examples/self-signed-certs/cert.pem",
557                "examples/self-signed-certs/key.pem",
558            )
559            .await?;
560
561            let addr = SocketAddr::from(([127, 0, 0, 1], 0));
562
563            tls_rustls::bind_rustls(addr, config)
564                .handle(server_handle)
565                .serve(app.into_make_service())
566                .await
567        });
568
569        let addr = handle.listening().await.unwrap();
570
571        (handle, server_task, addr)
572    }
573
574    async fn get_first_cert(addr: SocketAddr) -> CertificateDer<'static> {
575        let stream = TcpStream::connect(addr).await.unwrap();
576        let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap();
577
578        let (_io, client_connection) = tls_stream.into_inner();
579
580        client_connection.peer_certificates().unwrap()[0].clone()
581    }
582
583    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
584        let stream = TcpStream::connect(addr).await.unwrap();
585        let tls_stream = TokioIo::new(tls_connector().connect(dns_name(), stream).await.unwrap());
586
587        let (send_request, connection) = handshake(tls_stream).await.unwrap();
588
589        let task = tokio::spawn(async move {
590            let _ = connection.await;
591        });
592
593        (send_request, task)
594    }
595
596    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
597        let (parts, body) = client
598            .send_request(Request::new(Body::empty()))
599            .await
600            .unwrap()
601            .into_parts();
602        let body = body.collect().await.unwrap().to_bytes();
603
604        (parts, body)
605    }
606
607    fn tls_connector() -> TlsConnector {
608        #[derive(Debug)]
609        struct NoVerify;
610
611        impl ServerCertVerifier for NoVerify {
612            fn verify_server_cert(
613                &self,
614                _end_entity: &CertificateDer,
615                _intermediates: &[CertificateDer],
616                _server_name: &ServerName,
617                _ocsp_response: &[u8],
618                _now: UnixTime,
619            ) -> Result<ServerCertVerified, rustls::Error> {
620                Ok(ServerCertVerified::assertion())
621            }
622
623            fn verify_tls12_signature(
624                &self,
625                _message: &[u8],
626                _cert: &CertificateDer<'_>,
627                _dss: &DigitallySignedStruct,
628            ) -> Result<HandshakeSignatureValid, Error> {
629                Ok(HandshakeSignatureValid::assertion())
630            }
631
632            fn verify_tls13_signature(
633                &self,
634                _message: &[u8],
635                _cert: &CertificateDer<'_>,
636                _dss: &DigitallySignedStruct,
637            ) -> Result<HandshakeSignatureValid, Error> {
638                Ok(HandshakeSignatureValid::assertion())
639            }
640
641            fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
642                vec![
643                    SignatureScheme::RSA_PKCS1_SHA1,
644                    SignatureScheme::RSA_PKCS1_SHA256,
645                    SignatureScheme::RSA_PKCS1_SHA384,
646                    SignatureScheme::RSA_PKCS1_SHA512,
647                    SignatureScheme::RSA_PSS_SHA256,
648                    SignatureScheme::RSA_PSS_SHA384,
649                    SignatureScheme::RSA_PSS_SHA512,
650                ]
651            }
652        }
653
654        let mut client_config = ClientConfig::builder()
655            .dangerous()
656            .with_custom_certificate_verifier(Arc::new(NoVerify))
657            .with_no_client_auth();
658
659        client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
660
661        TlsConnector::from(Arc::new(client_config))
662    }
663
664    fn dns_name() -> ServerName<'static> {
665        ServerName::try_from("localhost").unwrap()
666    }
667}