// axum_server/tls_openssl/mod.rs

//! TLS implementation using [`openssl`].
//!
//! # Example
//!
//! ```rust,no_run
//! use axum::{routing::get, Router};
//! use axum_server::tls_openssl::OpenSSLConfig;
//! use std::net::SocketAddr;
//!
//! #[tokio::main]
//! async fn main() {
//!     let app = Router::new().route("/", get(|| async { "Hello, world!" }));
//!
//!     let config = OpenSSLConfig::from_pem_file(
//!         "examples/self-signed-certs/cert.pem",
//!         "examples/self-signed-certs/key.pem",
//!     )
//!     .unwrap();
//!
//!     let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
//!     println!("listening on {}", addr);
//!     axum_server::bind_openssl(addr, config)
//!         .serve(app.into_make_service())
//!         .await
//!         .unwrap();
//! }
//! ```
28
29use self::future::OpenSSLAcceptorFuture;
30use crate::{
31    accept::{Accept, DefaultAcceptor},
32    server::Server,
33    Address,
34};
35use arc_swap::ArcSwap;
36use openssl::{
37    pkey::PKey,
38    ssl::{
39        self, AlpnError, Error as OpenSSLError, SslAcceptor, SslAcceptorBuilder, SslFiletype,
40        SslMethod, SslRef,
41    },
42    x509::X509,
43};
44use std::{convert::TryFrom, fmt, path::Path, sync::Arc, time::Duration};
45use tokio::io::{AsyncRead, AsyncWrite};
46use tokio_openssl::SslStream;
47
48pub mod future;
49
50/// Create a TLS server that will be bound to the provided socket with a configuration. See
51/// the [`crate::tls_openssl`] module for more details.
52pub fn bind_openssl<A: Address>(addr: A, config: OpenSSLConfig) -> Server<A, OpenSSLAcceptor> {
53    let acceptor = OpenSSLAcceptor::new(config);
54
55    Server::bind(addr).acceptor(acceptor)
56}
57
58/// Tls acceptor that uses OpenSSL. For details on how to use this see [`crate::tls_openssl`] module
59/// for more details.
60#[derive(Clone)]
61pub struct OpenSSLAcceptor<A = DefaultAcceptor> {
62    inner: A,
63    config: OpenSSLConfig,
64    handshake_timeout: Duration,
65}
66
67impl OpenSSLAcceptor {
68    /// Create a new OpenSSL acceptor based on the provided [`OpenSSLConfig`]. This is
69    /// generally used with manual calls to [`Server::bind`]. You may want [`bind_openssl`]
70    /// instead.
71    pub fn new(config: OpenSSLConfig) -> Self {
72        let inner = DefaultAcceptor::new();
73
74        #[cfg(not(test))]
75        let handshake_timeout = Duration::from_secs(10);
76
77        // Don't force tests to wait too long.
78        #[cfg(test)]
79        let handshake_timeout = Duration::from_secs(1);
80
81        Self {
82            inner,
83            config,
84            handshake_timeout,
85        }
86    }
87
88    /// Override the default TLS handshake timeout of 10 seconds.
89    pub fn handshake_timeout(mut self, val: Duration) -> Self {
90        self.handshake_timeout = val;
91        self
92    }
93}
94
95impl<A> OpenSSLAcceptor<A> {
96    /// Overwrite inner acceptor.
97    pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> OpenSSLAcceptor<Acceptor> {
98        OpenSSLAcceptor {
99            inner: acceptor,
100            config: self.config,
101            handshake_timeout: self.handshake_timeout,
102        }
103    }
104}
105
106impl<A, I, S> Accept<I, S> for OpenSSLAcceptor<A>
107where
108    A: Accept<I, S>,
109    A::Stream: AsyncRead + AsyncWrite + Unpin,
110{
111    type Stream = SslStream<A::Stream>;
112    type Service = A::Service;
113    type Future = OpenSSLAcceptorFuture<A::Future, A::Stream, A::Service>;
114
115    fn accept(&self, stream: I, service: S) -> Self::Future {
116        let inner_future = self.inner.accept(stream, service);
117        let config = self.config.clone();
118
119        OpenSSLAcceptorFuture::new(inner_future, config, self.handshake_timeout)
120    }
121}
122
123impl<A> fmt::Debug for OpenSSLAcceptor<A> {
124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125        f.debug_struct("OpenSSLAcceptor").finish()
126    }
127}
128
129/// OpenSSL configuration.
130#[derive(Clone)]
131pub struct OpenSSLConfig {
132    acceptor: Arc<ArcSwap<SslAcceptor>>,
133}
134
135impl OpenSSLConfig {
136    /// Create config from `Arc<`[`SslAcceptor`]`>`.
137    pub fn from_acceptor(acceptor: Arc<SslAcceptor>) -> Self {
138        let acceptor = Arc::new(ArcSwap::new(acceptor));
139
140        OpenSSLConfig { acceptor }
141    }
142
143    /// This helper will establish a TLS server based on strong cipher suites
144    /// from a DER-encoded certificate and key.
145    pub fn from_der(cert: &[u8], key: &[u8]) -> Result<Self, OpenSSLError> {
146        let acceptor = Arc::new(ArcSwap::from_pointee(config_from_der(cert, key)?));
147
148        Ok(OpenSSLConfig { acceptor })
149    }
150
151    /// This helper will establish a TLS server based on strong cipher suites
152    /// from a PEM-formatted certificate and key.
153    pub fn from_pem(cert: &[u8], key: &[u8]) -> Result<Self, OpenSSLError> {
154        let acceptor = Arc::new(ArcSwap::from_pointee(config_from_pem(cert, key)?));
155
156        Ok(OpenSSLConfig { acceptor })
157    }
158
159    /// This helper will establish a TLS server based on strong cipher suites
160    /// from a PEM-formatted certificate and key.
161    pub fn from_pem_file(
162        cert: impl AsRef<Path>,
163        key: impl AsRef<Path>,
164    ) -> Result<Self, OpenSSLError> {
165        let acceptor = Arc::new(ArcSwap::from_pointee(config_from_pem_file(cert, key)?));
166
167        Ok(OpenSSLConfig { acceptor })
168    }
169
170    /// This helper will establish a TLS server based on strong cipher suites
171    /// from a PEM-formatted certificate chain and key.
172    pub fn from_pem_chain_file(
173        chain: impl AsRef<Path>,
174        key: impl AsRef<Path>,
175    ) -> Result<Self, OpenSSLError> {
176        let acceptor = Arc::new(ArcSwap::from_pointee(config_from_pem_chain_file(
177            chain, key,
178        )?));
179
180        Ok(OpenSSLConfig { acceptor })
181    }
182
183    /// Get inner `Arc<`[`SslAcceptor`]`>`.
184    pub fn get_inner(&self) -> Arc<SslAcceptor> {
185        self.acceptor.load_full()
186    }
187
188    /// Reload acceptor from `Arc<`[`SslAcceptor`]`>`.
189    pub fn reload_from_acceptor(&self, acceptor: Arc<SslAcceptor>) {
190        self.acceptor.store(acceptor);
191    }
192
193    /// Reload acceptor from a DER-encoded certificate and key.
194    pub fn reload_from_der(&self, cert: &[u8], key: &[u8]) -> Result<(), OpenSSLError> {
195        let acceptor = Arc::new(config_from_der(cert, key)?);
196        self.acceptor.store(acceptor);
197
198        Ok(())
199    }
200
201    /// Reload acceptor from a PEM-formatted certificate and key.
202    pub fn reload_from_pem(&self, cert: &[u8], key: &[u8]) -> Result<(), OpenSSLError> {
203        let acceptor = Arc::new(config_from_pem(cert, key)?);
204        self.acceptor.store(acceptor);
205
206        Ok(())
207    }
208
209    /// Reload acceptor from a PEM-formatted certificate and key.
210    pub fn reload_from_pem_file(
211        &self,
212        cert: impl AsRef<Path>,
213        key: impl AsRef<Path>,
214    ) -> Result<(), OpenSSLError> {
215        let acceptor = Arc::new(config_from_pem_file(cert, key)?);
216        self.acceptor.store(acceptor);
217
218        Ok(())
219    }
220
221    /// Reload acceptor from a PEM-formatted certificate chain and key.
222    pub fn reload_from_pem_chain_file(
223        &self,
224        chain: impl AsRef<Path>,
225        key: impl AsRef<Path>,
226    ) -> Result<(), OpenSSLError> {
227        let acceptor = Arc::new(config_from_pem_chain_file(chain, key)?);
228        self.acceptor.store(acceptor);
229
230        Ok(())
231    }
232}
233
234impl TryFrom<SslAcceptorBuilder> for OpenSSLConfig {
235    type Error = OpenSSLError;
236
237    /// Build the [`OpenSSLConfig`] from an [`SslAcceptorBuilder`]. This allows precise
238    /// control over the settings that will be used by OpenSSL in this server.
239    ///
240    /// # Example
241    /// ```
242    /// use axum_server::tls_openssl::OpenSSLConfig;
243    /// use openssl::ssl::{SslAcceptor, SslMethod};
244    /// use std::convert::TryFrom;
245    ///
246    /// #[tokio::main]
247    /// async fn main() {
248    ///     let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())
249    ///         .unwrap();
250    ///     // Set configurations like set_certificate_chain_file or
251    ///     // set_private_key_file.
252    ///     // let tls_builder.set_ ... ;
253    ///     let _config = OpenSSLConfig::try_from(tls_builder);
254    /// }
255    /// ```
256    fn try_from(mut tls_builder: SslAcceptorBuilder) -> Result<Self, Self::Error> {
257        // Any other checks?
258        tls_builder.check_private_key()?;
259        tls_builder.set_alpn_select_callback(alpn_select);
260
261        let acceptor = Arc::new(ArcSwap::from_pointee(tls_builder.build()));
262
263        Ok(OpenSSLConfig { acceptor })
264    }
265}
266
267impl fmt::Debug for OpenSSLConfig {
268    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269        f.debug_struct("OpenSSLConfig").finish()
270    }
271}
272
273fn alpn_select<'a>(_tls: &mut SslRef, client: &'a [u8]) -> Result<&'a [u8], AlpnError> {
274    ssl::select_next_proto(b"\x02h2\x08http/1.1", client).ok_or(AlpnError::NOACK)
275}
276
277fn config_from_der(cert: &[u8], key: &[u8]) -> Result<SslAcceptor, OpenSSLError> {
278    let cert = X509::from_der(cert)?;
279    let key = PKey::private_key_from_der(key)?;
280
281    let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
282    tls_builder.set_certificate(&cert)?;
283    tls_builder.set_private_key(&key)?;
284    tls_builder.check_private_key()?;
285    tls_builder.set_alpn_select_callback(alpn_select);
286
287    let acceptor = tls_builder.build();
288    Ok(acceptor)
289}
290
291fn config_from_pem(cert: &[u8], key: &[u8]) -> Result<SslAcceptor, OpenSSLError> {
292    let cert = X509::from_pem(cert)?;
293    let key = PKey::private_key_from_pem(key)?;
294
295    let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
296    tls_builder.set_certificate(&cert)?;
297    tls_builder.set_private_key(&key)?;
298    tls_builder.check_private_key()?;
299    tls_builder.set_alpn_select_callback(alpn_select);
300
301    let acceptor = tls_builder.build();
302    Ok(acceptor)
303}
304
305fn config_from_pem_file(
306    cert: impl AsRef<Path>,
307    key: impl AsRef<Path>,
308) -> Result<SslAcceptor, OpenSSLError> {
309    let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
310    tls_builder.set_certificate_file(cert, SslFiletype::PEM)?;
311    tls_builder.set_private_key_file(key, SslFiletype::PEM)?;
312    tls_builder.check_private_key()?;
313    tls_builder.set_alpn_select_callback(alpn_select);
314
315    let acceptor = tls_builder.build();
316    Ok(acceptor)
317}
318
319fn config_from_pem_chain_file(
320    chain: impl AsRef<Path>,
321    key: impl AsRef<Path>,
322) -> Result<SslAcceptor, OpenSSLError> {
323    let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
324    tls_builder.set_certificate_chain_file(chain)?;
325    tls_builder.set_private_key_file(key, SslFiletype::PEM)?;
326    tls_builder.check_private_key()?;
327    tls_builder.set_alpn_select_callback(alpn_select);
328
329    let acceptor = tls_builder.build();
330    Ok(acceptor)
331}
332
#[cfg(test)]
mod tests {
    use crate::{
        handle::Handle,
        tls_openssl::{self, OpenSSLConfig},
    };
    use axum::body::Body;
    use axum::routing::get;
    use axum::Router;
    use bytes::Bytes;
    use http::{response, Request};
    use http_body_util::BodyExt;
    use hyper::client::conn::http1::{handshake, SendRequest};
    use hyper_util::rt::TokioIo;
    use std::{io, net::SocketAddr};
    use tokio::{net::TcpStream, task::JoinHandle};

    use openssl::{
        ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode},
        x509::X509,
    };
    use std::pin::Pin;
    use tokio_openssl::SslStream;

    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    #[tokio::test]
    async fn test_reload() {
        let handle = Handle::new();

        let config = OpenSSLConfig::from_pem_file(
            "examples/self-signed-certs/cert.pem",
            "examples/self-signed-certs/key.pem",
        )
        .unwrap();

        let server_handle = handle.clone();
        let openssl_config = config.clone();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_openssl::bind_openssl(addr, openssl_config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        let cert_a = get_first_cert(addr).await;
        let mut cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/reload/cert.pem",
                "examples/self-signed-certs/reload/key.pem",
            )
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_ne!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);
    }

    async fn start_server() -> (Handle<SocketAddr>, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let config = OpenSSLConfig::from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .unwrap();

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_openssl::bind_openssl(addr, config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    async fn get_first_cert(addr: SocketAddr) -> X509 {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = tls_connector(dns_name(), stream).await;

        tls_stream.ssl().peer_certificate().unwrap()
    }

    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = TokioIo::new(tls_connector(dns_name(), stream).await);

        let (send_request, connection) = handshake(tls_stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let (parts, body) = client
            .send_request(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_parts();
        let body = body.collect().await.unwrap().to_bytes();

        (parts, body)
    }

    /// Open a TLS client session over `stream`, sending `hostname` as SNI.
    ///
    /// Certificate verification is disabled because the test certificates are self-signed.
    async fn tls_connector(hostname: &str, stream: TcpStream) -> SslStream<TcpStream> {
        let mut tls_parms = SslConnector::builder(SslMethod::tls_client()).unwrap();
        tls_parms.set_verify(SslVerifyMode::NONE);
        let tls_parms = tls_parms.build();

        let mut ssl = Ssl::new(tls_parms.context()).unwrap();
        // Set SNI directly on the client session. A client-hello callback only fires on
        // the server side of a handshake, so it never runs on a client context.
        ssl.set_hostname(hostname).unwrap();
        let mut tls_stream = SslStream::new(ssl, stream).unwrap();

        SslStream::connect(Pin::new(&mut tls_stream)).await.unwrap();

        tls_stream
    }

    fn dns_name() -> &'static str {
        "localhost"
    }
}