axum_server/tls_rustls/
mod.rs

1//! Tls implementation using [`rustls`].
2//!
3//! # Example
4//!
5//! ```rust,no_run
6//! use axum::{routing::get, Router};
7//! use axum_server::tls_rustls::RustlsConfig;
8//! use std::net::SocketAddr;
9//!
10//! #[tokio::main]
11//! async fn main() {
12//!     let app = Router::new().route("/", get(|| async { "Hello, world!" }));
13//!
14//!     let config = RustlsConfig::from_pem_file(
15//!         "examples/self-signed-certs/cert.pem",
16//!         "examples/self-signed-certs/key.pem",
17//!     )
18//!     .await
19//!     .unwrap();
20//!
21//!     let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
22//!     println!("listening on {}", addr);
23//!     axum_server::bind_rustls(addr, config)
24//!         .serve(app.into_make_service())
25//!         .await
26//!         .unwrap();
27//! }
28//! ```
29
30use self::future::RustlsAcceptorFuture;
31use crate::{
32    accept::{Accept, DefaultAcceptor},
33    server::{io_other, Server},
34};
35use arc_swap::ArcSwap;
36use rustls::ServerConfig;
37use rustls_pemfile::Item;
38use rustls_pki_types::{CertificateDer, PrivateKeyDer};
39use std::time::Duration;
40use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc};
41use tokio::{
42    io::{AsyncRead, AsyncWrite},
43    task::spawn_blocking,
44};
45use tokio_rustls::server::TlsStream;
46
47pub(crate) mod export {
48    #[allow(clippy::wildcard_imports)]
49    use super::*;
50
51    /// Create a tls server that will bind to provided address.
52    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
53    pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
54        super::bind_rustls(addr, config)
55    }
56
57    /// Create a tls server from existing `std::net::TcpListener`.
58    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
59    pub fn from_tcp_rustls(
60        listener: std::net::TcpListener,
61        config: RustlsConfig,
62    ) -> Server<RustlsAcceptor> {
63        let acceptor = RustlsAcceptor::new(config);
64
65        Server::from_tcp(listener).acceptor(acceptor)
66    }
67}
68
69pub mod future;
70
71/// Create a tls server that will bind to provided address.
72pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
73    let acceptor = RustlsAcceptor::new(config);
74
75    Server::bind(addr).acceptor(acceptor)
76}
77
78/// Create a tls server from existing `std::net::TcpListener`.
79pub fn from_tcp_rustls(
80    listener: std::net::TcpListener,
81    config: RustlsConfig,
82) -> Server<RustlsAcceptor> {
83    let acceptor = RustlsAcceptor::new(config);
84
85    Server::from_tcp(listener).acceptor(acceptor)
86}
87
88/// Tls acceptor using rustls.
89#[derive(Clone)]
90pub struct RustlsAcceptor<A = DefaultAcceptor> {
91    inner: A,
92    config: RustlsConfig,
93    handshake_timeout: Duration,
94}
95
96impl RustlsAcceptor {
97    /// Create a new rustls acceptor.
98    pub fn new(config: RustlsConfig) -> Self {
99        let inner = DefaultAcceptor::new();
100
101        #[cfg(not(test))]
102        let handshake_timeout = Duration::from_secs(10);
103
104        // Don't force tests to wait too long.
105        #[cfg(test)]
106        let handshake_timeout = Duration::from_secs(1);
107
108        Self {
109            inner,
110            config,
111            handshake_timeout,
112        }
113    }
114
115    /// Override the default TLS handshake timeout of 10 seconds, except during testing.
116    pub fn handshake_timeout(mut self, val: Duration) -> Self {
117        self.handshake_timeout = val;
118        self
119    }
120}
121
122impl<A> RustlsAcceptor<A> {
123    /// Overwrite inner acceptor.
124    pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> RustlsAcceptor<Acceptor> {
125        RustlsAcceptor {
126            inner: acceptor,
127            config: self.config,
128            handshake_timeout: self.handshake_timeout,
129        }
130    }
131}
132
133impl<A, I, S> Accept<I, S> for RustlsAcceptor<A>
134where
135    A: Accept<I, S>,
136    A::Stream: AsyncRead + AsyncWrite + Unpin,
137{
138    type Stream = TlsStream<A::Stream>;
139    type Service = A::Service;
140    type Future = RustlsAcceptorFuture<A::Future, A::Stream, A::Service>;
141
142    fn accept(&self, stream: I, service: S) -> Self::Future {
143        let inner_future = self.inner.accept(stream, service);
144        let config = self.config.clone();
145
146        RustlsAcceptorFuture::new(inner_future, config, self.handshake_timeout)
147    }
148}
149
150impl<A> fmt::Debug for RustlsAcceptor<A> {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        f.debug_struct("RustlsAcceptor").finish()
153    }
154}
155
156/// Rustls configuration.
157#[derive(Clone)]
158pub struct RustlsConfig {
159    inner: Arc<ArcSwap<ServerConfig>>,
160}
161
162impl RustlsConfig {
163    /// Create config from `Arc<`[`ServerConfig`]`>`.
164    ///
165    /// NOTE: You need to set ALPN protocols (like `http/1.1` or `h2`) manually.
166    pub fn from_config(config: Arc<ServerConfig>) -> Self {
167        let inner = Arc::new(ArcSwap::new(config));
168
169        Self { inner }
170    }
171
172    /// Create config from DER-encoded data.
173    ///
174    /// The certificate must be DER-encoded X.509.
175    ///
176    /// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
177    pub async fn from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<Self> {
178        let server_config = spawn_blocking(|| config_from_der(cert, key))
179            .await
180            .unwrap()?;
181        let inner = Arc::new(ArcSwap::from_pointee(server_config));
182
183        Ok(Self { inner })
184    }
185
186    /// Create config from PEM formatted data.
187    ///
188    /// Certificate and private key must be in PEM format.
189    pub async fn from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<Self> {
190        let server_config = spawn_blocking(|| config_from_pem(cert, key))
191            .await
192            .unwrap()?;
193        let inner = Arc::new(ArcSwap::from_pointee(server_config));
194
195        Ok(Self { inner })
196    }
197
198    /// Create config from PEM formatted files.
199    ///
200    /// Contents of certificate file and private key file must be in PEM format.
201    pub async fn from_pem_file(cert: impl AsRef<Path>, key: impl AsRef<Path>) -> io::Result<Self> {
202        let server_config = config_from_pem_file(cert, key).await?;
203        let inner = Arc::new(ArcSwap::from_pointee(server_config));
204
205        Ok(Self { inner })
206    }
207
208    /// Get  inner `Arc<`[`ServerConfig`]`>`.
209    pub fn get_inner(&self) -> Arc<ServerConfig> {
210        self.inner.load_full()
211    }
212
213    /// Reload config from `Arc<`[`ServerConfig`]`>`.
214    pub fn reload_from_config(&self, config: Arc<ServerConfig>) {
215        self.inner.store(config);
216    }
217
218    /// Reload config from DER-encoded data.
219    ///
220    /// The certificate must be DER-encoded X.509.
221    ///
222    /// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
223    pub async fn reload_from_der(&self, cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<()> {
224        let server_config = spawn_blocking(|| config_from_der(cert, key))
225            .await
226            .unwrap()?;
227        let inner = Arc::new(server_config);
228
229        self.inner.store(inner);
230
231        Ok(())
232    }
233
234    /// This helper will establish a TLS server based on strong cipher suites
235    /// from a PEM-formatted certificate chain and key.
236    pub async fn from_pem_chain_file(
237        chain: impl AsRef<Path>,
238        key: impl AsRef<Path>,
239    ) -> io::Result<Self> {
240        let server_config = config_from_pem_chain_file(chain, key).await?;
241        let inner = Arc::new(ArcSwap::from_pointee(server_config));
242
243        Ok(Self { inner })
244    }
245
246    /// Reload config from PEM formatted data.
247    ///
248    /// Certificate and private key must be in PEM format.
249    pub async fn reload_from_pem(&self, cert: Vec<u8>, key: Vec<u8>) -> io::Result<()> {
250        let server_config = spawn_blocking(|| config_from_pem(cert, key))
251            .await
252            .unwrap()?;
253        let inner = Arc::new(server_config);
254
255        self.inner.store(inner);
256
257        Ok(())
258    }
259
260    /// Reload config from PEM formatted files.
261    ///
262    /// Contents of certificate file and private key file must be in PEM format.
263    pub async fn reload_from_pem_file(
264        &self,
265        cert: impl AsRef<Path>,
266        key: impl AsRef<Path>,
267    ) -> io::Result<()> {
268        let server_config = config_from_pem_file(cert, key).await?;
269        let inner = Arc::new(server_config);
270
271        self.inner.store(inner);
272
273        Ok(())
274    }
275}
276
277impl fmt::Debug for RustlsConfig {
278    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279        f.debug_struct("RustlsConfig").finish()
280    }
281}
282
283fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig> {
284    let cert = cert.into_iter().map(CertificateDer::from).collect();
285    let key = PrivateKeyDer::try_from(key).map_err(io_other)?;
286
287    let mut config = ServerConfig::builder()
288        .with_no_client_auth()
289        .with_single_cert(cert, key)
290        .map_err(io_other)?;
291
292    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
293
294    Ok(config)
295}
296
297fn config_from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<ServerConfig> {
298    let cert = rustls_pemfile::certs(&mut cert.as_ref())
299        .map(|it| it.map(|it| it.to_vec()))
300        .collect::<Result<Vec<_>, _>>()?;
301    // Check the entire PEM file for the key in case it is not first section
302    let mut key_vec: Vec<Vec<u8>> = rustls_pemfile::read_all(&mut key.as_ref())
303        .filter_map(|i| match i.ok()? {
304            Item::Sec1Key(key) => Some(key.secret_sec1_der().to_vec()),
305            Item::Pkcs1Key(key) => Some(key.secret_pkcs1_der().to_vec()),
306            Item::Pkcs8Key(key) => Some(key.secret_pkcs8_der().to_vec()),
307            _ => None,
308        })
309        .collect();
310
311    // Make sure file contains only one key
312    if key_vec.len() != 1 {
313        return Err(io_other("private key format not supported"));
314    }
315
316    config_from_der(cert, key_vec.pop().unwrap())
317}
318
319async fn config_from_pem_file(
320    cert: impl AsRef<Path>,
321    key: impl AsRef<Path>,
322) -> io::Result<ServerConfig> {
323    let cert = fs_err::tokio::read(cert.as_ref()).await?;
324    let key = fs_err::tokio::read(key.as_ref()).await?;
325
326    config_from_pem(cert, key)
327}
328
329async fn config_from_pem_chain_file(
330    cert: impl AsRef<Path>,
331    chain: impl AsRef<Path>,
332) -> io::Result<ServerConfig> {
333    let cert = fs_err::tokio::read(cert.as_ref()).await?;
334    let cert = rustls_pemfile::certs(&mut cert.as_ref())
335        .map(|it| it.map(|it| CertificateDer::from(it.to_vec())))
336        .collect::<Result<Vec<_>, _>>()?;
337    let key = fs_err::tokio::read(chain.as_ref()).await?;
338    let key_cert: PrivateKeyDer = match rustls_pemfile::read_one(&mut key.as_ref())?
339        .ok_or_else(|| io_other("could not parse pem file"))?
340    {
341        Item::Pkcs8Key(key) => Ok(key.into()),
342        Item::Sec1Key(key) => Ok(key.into()),
343        Item::Pkcs1Key(key) => Ok(key.into()),
344        x => Err(io_other(format!(
345            "invalid certificate format, received: {x:?}"
346        ))),
347    }?;
348
349    ServerConfig::builder()
350        .with_no_client_auth()
351        .with_single_cert(cert, key_cert)
352        .map_err(|_| io_other("invalid certificate"))
353}
354
#[cfg(test)]
mod tests {
    use crate::handle::Handle;
    use crate::tls_rustls::{self, RustlsConfig};
    use axum::{body::Body, routing::get, Router};
    use bytes::Bytes;
    use http::{response, Request};
    use http_body_util::BodyExt;
    use hyper::client::conn::http1::{handshake, SendRequest};
    use hyper_util::rt::TokioIo;
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::{ClientConfig, DigitallySignedStruct, Error, SignatureScheme};
    use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
    use std::fmt::Debug;
    use std::{convert::TryFrom, io, net::SocketAddr, sync::Arc, time::Duration};
    use tokio::{net::TcpStream, task::JoinHandle, time::sleep};
    use tokio_rustls::TlsConnector;

    /// Self-signed certificate the test servers start with.
    const CERT_PEM: &str = "examples/self-signed-certs/cert.pem";
    /// Private key matching `CERT_PEM`.
    const KEY_PEM: &str = "examples/self-signed-certs/key.pem";
    /// Alternate certificate used to verify hot reloading.
    const RELOAD_CERT_PEM: &str = "examples/self-signed-certs/reload/cert.pem";
    /// Private key matching `RELOAD_CERT_PEM`.
    const RELOAD_KEY_PEM: &str = "examples/self-signed-certs/reload/key.pem";
    /// Path that must not exist, used to provoke `NotFound` errors.
    const MISSING_PEM: &str = "examples/self-signed-certs/missing.pem";
    /// Expected error text when reading `MISSING_PEM`.
    const MISSING_PEM_ERROR: &str = "failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)";

    /// Router that answers `/` with a fixed greeting.
    fn hello_router() -> Router {
        Router::new().route("/", get(|| async { "Hello, world!" }))
    }

    /// Loopback address with an OS-assigned port.
    fn ephemeral_addr() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 0))
    }

    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;
        let (mut client, _conn) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    #[ignore]
    #[tokio::test]
    async fn tls_timeout() {
        let (handle, _server_task, addr) = start_server().await;
        assert_eq!(handle.connection_count(), 0);

        // Open a raw TCP connection but never drive the TLS handshake.
        let _idle_stream = TcpStream::connect(addr).await.unwrap();

        sleep(Duration::from_millis(500)).await;
        assert_eq!(handle.connection_count(), 1);

        // The handshake timeout is 1s in test builds; after ~1.5s in total the
        // server should have dropped the idle connection.
        sleep(Duration::from_millis(1000)).await;
        assert_eq!(handle.connection_count(), 0);
    }

    #[tokio::test]
    async fn test_reload() {
        let handle = Handle::new();
        let config = RustlsConfig::from_pem_file(CERT_PEM, KEY_PEM).await.unwrap();

        tokio::spawn({
            let handle = handle.clone();
            let config = config.clone();
            async move {
                tls_rustls::bind_rustls(ephemeral_addr(), config)
                    .handle(handle)
                    .serve(hello_router().into_make_service())
                    .await
            }
        });

        let addr = handle.listening().await.unwrap();

        // Two fresh connections must see the same certificate.
        let original = get_first_cert(addr).await;
        assert_eq!(original, get_first_cert(addr).await);

        // After a reload, new connections see the alternate certificate...
        config
            .reload_from_pem_file(RELOAD_CERT_PEM, RELOAD_KEY_PEM)
            .await
            .unwrap();
        assert_ne!(original, get_first_cert(addr).await);

        // ...and reloading the original files restores the original certificate.
        config
            .reload_from_pem_file(CERT_PEM, KEY_PEM)
            .await
            .unwrap();
        assert_eq!(original, get_first_cert(addr).await);
    }

    /// Spawn a TLS server on an ephemeral loopback port using the default test cert.
    ///
    /// Returns the server handle, the server task and the bound address.
    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();
        let server_handle = handle.clone();

        let server_task = tokio::spawn(async move {
            let config = RustlsConfig::from_pem_file(CERT_PEM, KEY_PEM).await?;

            tls_rustls::bind_rustls(ephemeral_addr(), config)
                .handle(server_handle)
                .serve(hello_router().into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Complete a TLS handshake with `addr` and return the server's leaf certificate.
    async fn get_first_cert(addr: SocketAddr) -> CertificateDer<'static> {
        let tcp = TcpStream::connect(addr).await.unwrap();
        let tls = tls_connector().connect(dns_name(), tcp).await.unwrap();

        let (_io, session) = tls.into_inner();
        let chain = session.peer_certificates().unwrap();

        chain[0].clone()
    }

    /// Open an HTTP/1 client over TLS to `addr`, driving the connection in the background.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let tcp = TcpStream::connect(addr).await.unwrap();
        let tls = tls_connector().connect(dns_name(), tcp).await.unwrap();

        let (sender, connection) = handshake(TokioIo::new(tls)).await.unwrap();

        // The connection's final result is irrelevant to the tests.
        let driver = tokio::spawn(async move {
            let _ = connection.await;
        });

        (sender, driver)
    }

    /// Send `GET /` with an empty body and collect the full response body.
    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let response = client
            .send_request(Request::new(Body::empty()))
            .await
            .unwrap();
        let (parts, body) = response.into_parts();
        let bytes = body.collect().await.unwrap().to_bytes();

        (parts, bytes)
    }

    /// Verifier that accepts any server certificate; the test certs are self-signed.
    #[derive(Debug)]
    struct NoVerify;

    impl ServerCertVerifier for NoVerify {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer,
            _intermediates: &[CertificateDer],
            _server_name: &ServerName,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            vec![
                SignatureScheme::RSA_PKCS1_SHA1,
                SignatureScheme::RSA_PKCS1_SHA256,
                SignatureScheme::RSA_PKCS1_SHA384,
                SignatureScheme::RSA_PKCS1_SHA512,
                SignatureScheme::RSA_PSS_SHA256,
                SignatureScheme::RSA_PSS_SHA384,
                SignatureScheme::RSA_PSS_SHA512,
            ]
        }
    }

    /// TLS connector that skips certificate verification and offers `h2`/`http/1.1`.
    fn tls_connector() -> TlsConnector {
        let mut config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoVerify))
            .with_no_client_auth();

        config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

        TlsConnector::from(Arc::new(config))
    }

    /// Server name presented during the TLS handshake.
    fn dns_name() -> ServerName<'static> {
        ServerName::try_from("localhost").unwrap()
    }

    /// Assert that `err` is the `NotFound` error for `MISSING_PEM`.
    #[track_caller]
    fn assert_missing_file(err: io::Error) {
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        assert_eq!(err.to_string(), MISSING_PEM_ERROR);
    }

    #[tokio::test]
    async fn from_pem_file_not_found() {
        let missing_cert = RustlsConfig::from_pem_file(MISSING_PEM, KEY_PEM)
            .await
            .unwrap_err();
        assert_missing_file(missing_cert);

        let missing_key = RustlsConfig::from_pem_file(CERT_PEM, MISSING_PEM)
            .await
            .unwrap_err();
        assert_missing_file(missing_key);
    }

    #[tokio::test]
    async fn from_pem_file_chain_file_not_found() {
        let missing_chain = RustlsConfig::from_pem_chain_file(MISSING_PEM, KEY_PEM)
            .await
            .unwrap_err();
        assert_missing_file(missing_chain);

        let missing_key = RustlsConfig::from_pem_chain_file(CERT_PEM, MISSING_PEM)
            .await
            .unwrap_err();
        assert_missing_file(missing_key);
    }
}