axum_server/tls_rustls/
mod.rs

//! TLS implementation using [`rustls`].
2//!
3//! # Example
4//!
5//! ```rust,no_run
6//! use axum::{routing::get, Router};
7//! use axum_server::tls_rustls::RustlsConfig;
8//! use std::net::SocketAddr;
9//!
10//! #[tokio::main]
11//! async fn main() {
12//!     let app = Router::new().route("/", get(|| async { "Hello, world!" }));
13//!
14//!     let config = RustlsConfig::from_pem_file(
15//!         "examples/self-signed-certs/cert.pem",
16//!         "examples/self-signed-certs/key.pem",
17//!     )
18//!     .await
19//!     .unwrap();
20//!
21//!     let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
22//!     println!("listening on {}", addr);
23//!     axum_server::bind_rustls(addr, config)
24//!         .serve(app.into_make_service())
25//!         .await
26//!         .unwrap();
27//! }
28//! ```
29
30use self::future::RustlsAcceptorFuture;
31use crate::{
32    accept::{Accept, DefaultAcceptor},
33    server::{io_other, Server},
34    Address,
35};
36use arc_swap::ArcSwap;
37use rustls::ServerConfig;
38use rustls_pki_types::pem::PemObject;
39use rustls_pki_types::{CertificateDer, PrivateKeyDer};
40use std::time::Duration;
41use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc};
42use tokio::{
43    io::{AsyncRead, AsyncWrite},
44    task::spawn_blocking,
45};
46use tokio_rustls::server::TlsStream;
47
48pub(crate) mod export {
49    #[allow(clippy::wildcard_imports)]
50    use super::*;
51
52    /// Create a tls server that will bind to provided address.
53    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
54    pub fn bind_rustls<A: Address>(addr: A, config: RustlsConfig) -> Server<A, RustlsAcceptor> {
55        super::bind_rustls(addr, config)
56    }
57
58    /// Create a tls server from existing `std::net::TcpListener`.
59    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
60    pub fn from_tcp_rustls(
61        listener: std::net::TcpListener,
62        config: RustlsConfig,
63    ) -> io::Result<Server<SocketAddr, RustlsAcceptor>> {
64        let acceptor = RustlsAcceptor::new(config);
65
66        Ok(crate::from_tcp(listener)?.acceptor(acceptor))
67    }
68
69    /// Create a tls server from existing `std::os::unix::net::UnixListener`.
70    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
71    #[cfg(unix)]
72    pub fn from_unix_rustls(
73        listener: std::os::unix::net::UnixListener,
74        config: RustlsConfig,
75    ) -> io::Result<Server<std::os::unix::net::SocketAddr, RustlsAcceptor>> {
76        let acceptor = RustlsAcceptor::new(config);
77
78        Ok(crate::from_unix(listener)?.acceptor(acceptor))
79    }
80}
81
82pub mod future;
83
84/// Create a tls server that will bind to provided address.
85pub fn bind_rustls<A: Address>(addr: A, config: RustlsConfig) -> Server<A, RustlsAcceptor> {
86    let acceptor = RustlsAcceptor::new(config);
87
88    Server::bind(addr).acceptor(acceptor)
89}
90
91/// Create a tls server from existing `std::net::TcpListener`.
92pub fn from_tcp_rustls(
93    listener: std::net::TcpListener,
94    config: RustlsConfig,
95) -> io::Result<Server<SocketAddr, RustlsAcceptor>> {
96    let acceptor = RustlsAcceptor::new(config);
97
98    Ok(crate::from_tcp(listener)?.acceptor(acceptor))
99}
100
101/// Create a tls server from existing `std::os::unix::net::UnixListener`.
102#[cfg(unix)]
103pub fn from_unix_rustls(
104    listener: std::os::unix::net::UnixListener,
105    config: RustlsConfig,
106) -> io::Result<Server<std::os::unix::net::SocketAddr, RustlsAcceptor>> {
107    let acceptor = RustlsAcceptor::new(config);
108
109    Ok(crate::from_unix(listener)?.acceptor(acceptor))
110}
111
112/// Tls acceptor using rustls.
113#[derive(Clone)]
114pub struct RustlsAcceptor<A = DefaultAcceptor> {
115    inner: A,
116    config: RustlsConfig,
117    handshake_timeout: Duration,
118}
119
120impl RustlsAcceptor {
121    /// Create a new rustls acceptor.
122    pub fn new(config: RustlsConfig) -> Self {
123        let inner = DefaultAcceptor::new();
124
125        #[cfg(not(test))]
126        let handshake_timeout = Duration::from_secs(10);
127
128        // Don't force tests to wait too long.
129        #[cfg(test)]
130        let handshake_timeout = Duration::from_secs(1);
131
132        Self {
133            inner,
134            config,
135            handshake_timeout,
136        }
137    }
138
139    /// Override the default TLS handshake timeout of 10 seconds, except during testing.
140    pub fn handshake_timeout(mut self, val: Duration) -> Self {
141        self.handshake_timeout = val;
142        self
143    }
144}
145
146impl<A> RustlsAcceptor<A> {
147    /// Overwrite inner acceptor.
148    pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> RustlsAcceptor<Acceptor> {
149        RustlsAcceptor {
150            inner: acceptor,
151            config: self.config,
152            handshake_timeout: self.handshake_timeout,
153        }
154    }
155}
156
157impl<A, I, S> Accept<I, S> for RustlsAcceptor<A>
158where
159    A: Accept<I, S>,
160    A::Stream: AsyncRead + AsyncWrite + Unpin,
161{
162    type Stream = TlsStream<A::Stream>;
163    type Service = A::Service;
164    type Future = RustlsAcceptorFuture<A::Future, A::Stream, A::Service>;
165
166    fn accept(&self, stream: I, service: S) -> Self::Future {
167        let inner_future = self.inner.accept(stream, service);
168        let config = self.config.clone();
169
170        RustlsAcceptorFuture::new(inner_future, config, self.handshake_timeout)
171    }
172}
173
174impl<A> fmt::Debug for RustlsAcceptor<A> {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        f.debug_struct("RustlsAcceptor").finish()
177    }
178}
179
180/// Rustls configuration.
181#[derive(Clone)]
182pub struct RustlsConfig {
183    inner: Arc<ArcSwap<ServerConfig>>,
184}
185
186impl RustlsConfig {
187    /// Create config from `Arc<`[`ServerConfig`]`>`.
188    ///
189    /// NOTE: You need to set ALPN protocols (like `http/1.1` or `h2`) manually.
190    pub fn from_config(config: Arc<ServerConfig>) -> Self {
191        let inner = Arc::new(ArcSwap::new(config));
192
193        Self { inner }
194    }
195
196    /// Create config from DER-encoded data.
197    ///
198    /// The certificate must be DER-encoded X.509.
199    ///
200    /// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
201    pub async fn from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<Self> {
202        let server_config = spawn_blocking(|| config_from_der(cert, key))
203            .await
204            .unwrap()?;
205        let inner = Arc::new(ArcSwap::from_pointee(server_config));
206
207        Ok(Self { inner })
208    }
209
210    /// Create config from PEM formatted data.
211    ///
212    /// Certificate and private key must be in PEM format.
213    pub async fn from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<Self> {
214        let server_config = spawn_blocking(|| config_from_pem(cert, key))
215            .await
216            .unwrap()?;
217        let inner = Arc::new(ArcSwap::from_pointee(server_config));
218
219        Ok(Self { inner })
220    }
221
222    /// Create config from PEM formatted files.
223    ///
224    /// Contents of certificate file and private key file must be in PEM format.
225    pub async fn from_pem_file(cert: impl AsRef<Path>, key: impl AsRef<Path>) -> io::Result<Self> {
226        let server_config = config_from_pem_file(cert, key).await?;
227        let inner = Arc::new(ArcSwap::from_pointee(server_config));
228
229        Ok(Self { inner })
230    }
231
232    /// Get  inner `Arc<`[`ServerConfig`]`>`.
233    pub fn get_inner(&self) -> Arc<ServerConfig> {
234        self.inner.load_full()
235    }
236
237    /// Reload config from `Arc<`[`ServerConfig`]`>`.
238    pub fn reload_from_config(&self, config: Arc<ServerConfig>) {
239        self.inner.store(config);
240    }
241
242    /// Reload config from DER-encoded data.
243    ///
244    /// The certificate must be DER-encoded X.509.
245    ///
246    /// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
247    pub async fn reload_from_der(&self, cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<()> {
248        let server_config = spawn_blocking(|| config_from_der(cert, key))
249            .await
250            .unwrap()?;
251        let inner = Arc::new(server_config);
252
253        self.inner.store(inner);
254
255        Ok(())
256    }
257
258    /// This helper will establish a TLS server based on strong cipher suites
259    /// from a PEM-formatted certificate chain and key.
260    pub async fn from_pem_chain_file(
261        chain: impl AsRef<Path>,
262        key: impl AsRef<Path>,
263    ) -> io::Result<Self> {
264        let server_config = config_from_pem_chain_file(chain, key).await?;
265        let inner = Arc::new(ArcSwap::from_pointee(server_config));
266
267        Ok(Self { inner })
268    }
269
270    /// Reload config from PEM formatted data.
271    ///
272    /// Certificate and private key must be in PEM format.
273    pub async fn reload_from_pem(&self, cert: Vec<u8>, key: Vec<u8>) -> io::Result<()> {
274        let server_config = spawn_blocking(|| config_from_pem(cert, key))
275            .await
276            .unwrap()?;
277        let inner = Arc::new(server_config);
278
279        self.inner.store(inner);
280
281        Ok(())
282    }
283
284    /// Reload config from PEM formatted files.
285    ///
286    /// Contents of certificate file and private key file must be in PEM format.
287    pub async fn reload_from_pem_file(
288        &self,
289        cert: impl AsRef<Path>,
290        key: impl AsRef<Path>,
291    ) -> io::Result<()> {
292        let server_config = config_from_pem_file(cert, key).await?;
293        let inner = Arc::new(server_config);
294
295        self.inner.store(inner);
296
297        Ok(())
298    }
299}
300
301impl fmt::Debug for RustlsConfig {
302    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303        f.debug_struct("RustlsConfig").finish()
304    }
305}
306
307fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig> {
308    let cert = cert.into_iter().map(CertificateDer::from).collect();
309    let key = PrivateKeyDer::try_from(key).map_err(io_other)?;
310
311    let mut config = ServerConfig::builder()
312        .with_no_client_auth()
313        .with_single_cert(cert, key)
314        .map_err(io_other)?;
315
316    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
317
318    Ok(config)
319}
320
321fn config_from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<ServerConfig> {
322    let cert: Vec<CertificateDer> = CertificateDer::pem_slice_iter(&cert)
323        .collect::<Result<Vec<_>, _>>()
324        .map_err(|_| io_other("failed to parse certificate"))?;
325
326    let mut key_result: Result<PrivateKeyDer, io::Error> =
327        Err(io_other("The private key file contained no keys"));
328
329    // Check the entire PEM file for the key in case it is not first section
330    for item in rustls_pki_types::pem::PemObject::pem_slice_iter(&key) {
331        let key: Result<PrivateKeyDer, io::Error> =
332            item.map_err(|_| io_other("failed to parse PEM"));
333
334        match key_result {
335            // if we already got a key, then...
336            Ok(_) => {
337                // ...if we get a key now, we know that there are multiple keys and that's not allowed
338                if key.is_ok() {
339                    return Err(io_other(
340                        "The private key file containsed multiple keys (it must only contain one)",
341                    ));
342                }
343            }
344            // but if already have an error, just overwrite it with whatever we got this time. If
345            // it's a good key, that's cool. If it's an error, then we're just ignoring the old
346            // error in favor of this new one
347            Err(_) => key_result = key,
348        }
349    }
350
351    let key = key_result?;
352    let cert_der: Vec<Vec<u8>> = cert.into_iter().map(|c| c.to_vec()).collect();
353    let key_der = key.secret_der().to_vec();
354
355    config_from_der(cert_der, key_der)
356}
357
358async fn config_from_pem_file(
359    cert: impl AsRef<Path>,
360    key: impl AsRef<Path>,
361) -> io::Result<ServerConfig> {
362    let cert = fs_err::tokio::read(cert.as_ref()).await?;
363    let key = fs_err::tokio::read(key.as_ref()).await?;
364
365    config_from_pem(cert, key)
366}
367
368async fn config_from_pem_chain_file(
369    cert: impl AsRef<Path>,
370    chain: impl AsRef<Path>,
371) -> io::Result<ServerConfig> {
372    let cert = fs_err::tokio::read(cert.as_ref()).await?;
373    let cert = CertificateDer::pem_slice_iter(&cert)
374        .collect::<Result<Vec<_>, _>>()
375        .map_err(|_| io_other("failed to parse certificate"))?;
376    let key = fs_err::tokio::read(chain.as_ref()).await?;
377    let key_cert: PrivateKeyDer =
378        PrivateKeyDer::from_pem_slice(&key).map_err(|_| io_other("could not parse pem file"))?;
379
380    ServerConfig::builder()
381        .with_no_client_auth()
382        .with_single_cert(cert, key_cert)
383        .map_err(|_| io_other("invalid certificate"))
384}
385
#[cfg(test)]
mod tests {
    use crate::handle::Handle;
    use crate::tls_rustls::{self, RustlsConfig};
    use axum::body::Body;
    use axum::routing::get;
    use axum::Router;
    use bytes::Bytes;
    use http::{response, Request};
    use http_body_util::BodyExt;
    use hyper::client::conn::http1::{handshake, SendRequest};
    use hyper_util::rt::TokioIo;
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::{ClientConfig, DigitallySignedStruct, Error, SignatureScheme};
    use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
    use std::fmt::Debug;
    use std::{convert::TryFrom, io, net::SocketAddr, sync::Arc, time::Duration};
    use tokio::time::sleep;
    use tokio::{net::TcpStream, task::JoinHandle};
    use tokio_rustls::TlsConnector;

    /// A request over a fresh TLS connection is answered by the router.
    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    /// A TCP connection that never finishes the TLS handshake is dropped once the handshake
    /// timeout (1 second under `cfg(test)`) elapses.
    #[ignore]
    #[tokio::test]
    async fn tls_timeout() {
        let (handle, _server_task, addr) = start_server().await;
        assert_eq!(handle.connection_count(), 0);

        // We intentionally avoid driving a TLS handshake to completion.
        let _stream = TcpStream::connect(addr).await.unwrap();

        sleep(Duration::from_millis(500)).await;
        assert_eq!(handle.connection_count(), 1);

        sleep(Duration::from_millis(1000)).await;
        // Timeout defaults to 1s during testing, and we have waited 1.5 seconds.
        assert_eq!(handle.connection_count(), 0);
    }

    /// Reloading a shared `RustlsConfig` changes the certificate presented to new clients,
    /// and reloading the original files restores it.
    #[tokio::test]
    async fn test_reload() {
        let handle = Handle::new();

        let config = RustlsConfig::from_pem_file(
            "examples/self-signed-certs/cert.pem",
            "examples/self-signed-certs/key.pem",
        )
        .await
        .unwrap();

        let server_handle = handle.clone();
        let rustls_config = config.clone();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_rustls::bind_rustls(addr, rustls_config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        let cert_a = get_first_cert(addr).await;
        let mut cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/reload/cert.pem",
                "examples/self-signed-certs/reload/key.pem",
            )
            .await
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_ne!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .await
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);
    }

    /// Spawn a TLS server on an ephemeral localhost port using the example certificates.
    ///
    /// Returns the server handle, the server task, and the address it is listening on.
    async fn start_server() -> (Handle<SocketAddr>, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let config = RustlsConfig::from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .await?;

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_rustls::bind_rustls(addr, config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Complete a TLS handshake with `addr` and return the first certificate the server sent.
    async fn get_first_cert(addr: SocketAddr) -> CertificateDer<'static> {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap();

        let (_io, client_connection) = tls_stream.into_inner();

        client_connection.peer_certificates().unwrap()[0].clone()
    }

    /// Open a TLS connection to `addr` and perform an HTTP/1 client handshake over it.
    ///
    /// The connection driver runs in the returned task.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = TokioIo::new(tls_connector().connect(dns_name(), stream).await.unwrap());

        let (send_request, connection) = handshake(tls_stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    /// Send a default (empty-body) request and collect the full response body.
    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let (parts, body) = client
            .send_request(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_parts();
        let body = body.collect().await.unwrap().to_bytes();

        (parts, body)
    }

    /// Build a client connector that accepts any server certificate (test-only).
    fn tls_connector() -> TlsConnector {
        #[derive(Debug)]
        struct NoVerify;

        impl ServerCertVerifier for NoVerify {
            fn verify_server_cert(
                &self,
                _end_entity: &CertificateDer,
                _intermediates: &[CertificateDer],
                _server_name: &ServerName,
                _ocsp_response: &[u8],
                _now: UnixTime,
            ) -> Result<ServerCertVerified, rustls::Error> {
                Ok(ServerCertVerified::assertion())
            }

            fn verify_tls12_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, Error> {
                Ok(HandshakeSignatureValid::assertion())
            }

            fn verify_tls13_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, Error> {
                Ok(HandshakeSignatureValid::assertion())
            }

            fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
                // Offer RSA, ECDSA and Ed25519 schemes so the handshake does not depend on
                // the key type of the example certificates.
                vec![
                    SignatureScheme::RSA_PKCS1_SHA1,
                    SignatureScheme::RSA_PKCS1_SHA256,
                    SignatureScheme::RSA_PKCS1_SHA384,
                    SignatureScheme::RSA_PKCS1_SHA512,
                    SignatureScheme::RSA_PSS_SHA256,
                    SignatureScheme::RSA_PSS_SHA384,
                    SignatureScheme::RSA_PSS_SHA512,
                    SignatureScheme::ECDSA_NISTP256_SHA256,
                    SignatureScheme::ECDSA_NISTP384_SHA384,
                    SignatureScheme::ECDSA_NISTP521_SHA512,
                    SignatureScheme::ED25519,
                ]
            }
        }

        let mut client_config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoVerify))
            .with_no_client_auth();

        client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

        TlsConnector::from(Arc::new(client_config))
    }

    /// Server name used for SNI; certificate checks are disabled by `NoVerify` anyway.
    fn dns_name() -> ServerName<'static> {
        ServerName::try_from("localhost").unwrap()
    }

    /// A missing certificate or key file yields `NotFound` with the path in the message.
    #[tokio::test]
    async fn from_pem_file_not_found() {
        let err = RustlsConfig::from_pem_file(
            "examples/self-signed-certs/missing.pem",
            "examples/self-signed-certs/key.pem",
        )
        .await
        .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        assert_eq!(
            err.to_string(),
            "failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)"
        );

        let err = RustlsConfig::from_pem_file(
            "examples/self-signed-certs/cert.pem",
            "examples/self-signed-certs/missing.pem",
        )
        .await
        .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        assert_eq!(
            err.to_string(),
            "failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)"
        );
    }

    /// Same as `from_pem_file_not_found`, for the chain-file constructor.
    #[tokio::test]
    async fn from_pem_file_chain_file_not_found() {
        let err = RustlsConfig::from_pem_chain_file(
            "examples/self-signed-certs/missing.pem",
            "examples/self-signed-certs/key.pem",
        )
        .await
        .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        assert_eq!(
            err.to_string(),
            "failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)"
        );

        let err = RustlsConfig::from_pem_chain_file(
            "examples/self-signed-certs/cert.pem",
            "examples/self-signed-certs/missing.pem",
        )
        .await
        .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        assert_eq!(
            err.to_string(),
            "failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)"
        );
    }
}