hyper_serve/tls_openssl/
mod.rs

//! TLS implementation using [`openssl`].
2//!
3//! # Example
4//!
5//! ```rust,no_run
6//! use axum::{routing::get, Router};
7//! use hyper_serve::tls_openssl::OpenSSLConfig;
8//! use std::net::SocketAddr;
9//!
10//! #[tokio::main]
11//! async fn main() {
12//!     let app = Router::new().route("/", get(|| async { "Hello, world!" }));
13//!
14//!     let config = OpenSSLConfig::from_pem_file(
15//!         "examples/self-signed-certs/cert.pem",
16//!         "examples/self-signed-certs/key.pem",
17//!     )
18//!     .unwrap();
19//!
20//!     let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
21//!     println!("listening on {}", addr);
22//!     hyper_serve::bind_openssl(addr, config)
23//!         .serve(app.into_make_service())
24//!         .await
25//!         .unwrap();
26//! }
27//! ```
28
29use self::future::OpenSSLAcceptorFuture;
30use crate::{
31    accept::{Accept, DefaultAcceptor},
32    server::Server,
33};
34use arc_swap::ArcSwap;
35use openssl::{
36    pkey::PKey,
37    ssl::{
38        self, AlpnError, Error as OpenSSLError, SslAcceptor, SslAcceptorBuilder, SslFiletype,
39        SslMethod, SslRef,
40    },
41    x509::X509,
42};
43use std::{convert::TryFrom, fmt, net::SocketAddr, path::Path, sync::Arc, time::Duration};
44use tokio::io::{AsyncRead, AsyncWrite};
45use tokio_openssl::SslStream;
46
47pub mod future;
48
49/// Create a TLS server that will be bound to the provided socket with a configuration. See
50/// the [`crate::tls_openssl`] module for more details.
51pub fn bind_openssl(addr: SocketAddr, config: OpenSSLConfig) -> Server<OpenSSLAcceptor> {
52    let acceptor = OpenSSLAcceptor::new(config);
53
54    Server::bind(addr).acceptor(acceptor)
55}
56
57/// Tls acceptor that uses OpenSSL. For details on how to use this see [`crate::tls_openssl`] module
58/// for more details.
59#[derive(Clone)]
60pub struct OpenSSLAcceptor<A = DefaultAcceptor> {
61    inner: A,
62    config: OpenSSLConfig,
63    handshake_timeout: Duration,
64}
65
66impl OpenSSLAcceptor {
67    /// Create a new OpenSSL acceptor based on the provided [`OpenSSLConfig`]. This is
68    /// generally used with manual calls to [`Server::bind`]. You may want [`bind_openssl`]
69    /// instead.
70    pub fn new(config: OpenSSLConfig) -> Self {
71        let inner = DefaultAcceptor::new();
72
73        #[cfg(not(test))]
74        let handshake_timeout = Duration::from_secs(10);
75
76        // Don't force tests to wait too long.
77        #[cfg(test)]
78        let handshake_timeout = Duration::from_secs(1);
79
80        Self {
81            inner,
82            config,
83            handshake_timeout,
84        }
85    }
86
87    /// Override the default TLS handshake timeout of 10 seconds.
88    pub fn handshake_timeout(mut self, val: Duration) -> Self {
89        self.handshake_timeout = val;
90        self
91    }
92}
93
94impl<A> OpenSSLAcceptor<A> {
95    /// Overwrite inner acceptor.
96    pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> OpenSSLAcceptor<Acceptor> {
97        OpenSSLAcceptor {
98            inner: acceptor,
99            config: self.config,
100            handshake_timeout: self.handshake_timeout,
101        }
102    }
103}
104
105impl<A, I, S> Accept<I, S> for OpenSSLAcceptor<A>
106where
107    A: Accept<I, S>,
108    A::Stream: AsyncRead + AsyncWrite + Unpin,
109{
110    type Stream = SslStream<A::Stream>;
111    type Service = A::Service;
112    type Future = OpenSSLAcceptorFuture<A::Future, A::Stream, A::Service>;
113
114    fn accept(&self, stream: I, service: S) -> Self::Future {
115        let inner_future = self.inner.accept(stream, service);
116        let config = self.config.clone();
117
118        OpenSSLAcceptorFuture::new(inner_future, config, self.handshake_timeout)
119    }
120}
121
122impl<A> fmt::Debug for OpenSSLAcceptor<A> {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        f.debug_struct("OpenSSLAcceptor").finish()
125    }
126}
127
128/// OpenSSL configuration.
129#[derive(Clone)]
130pub struct OpenSSLConfig {
131    acceptor: Arc<ArcSwap<SslAcceptor>>,
132}
133
134impl OpenSSLConfig {
135    /// Create config from `Arc<`[`SslAcceptor`]`>`.
136    pub fn from_acceptor(acceptor: Arc<SslAcceptor>) -> Self {
137        let acceptor = Arc::new(ArcSwap::new(acceptor));
138
139        OpenSSLConfig { acceptor }
140    }
141
142    /// This helper will establish a TLS server based on strong cipher suites
143    /// from a DER-encoded certificate and key.
144    pub fn from_der(cert: &[u8], key: &[u8]) -> Result<Self, OpenSSLError> {
145        let acceptor = Arc::new(ArcSwap::from_pointee(config_from_der(cert, key)?));
146
147        Ok(OpenSSLConfig { acceptor })
148    }
149
150    /// This helper will establish a TLS server based on strong cipher suites
151    /// from a PEM-formatted certificate and key.
152    pub fn from_pem(cert: &[u8], key: &[u8]) -> Result<Self, OpenSSLError> {
153        let acceptor = Arc::new(ArcSwap::from_pointee(config_from_pem(cert, key)?));
154
155        Ok(OpenSSLConfig { acceptor })
156    }
157
158    /// This helper will establish a TLS server based on strong cipher suites
159    /// from a PEM-formatted certificate and key.
160    pub fn from_pem_file(
161        cert: impl AsRef<Path>,
162        key: impl AsRef<Path>,
163    ) -> Result<Self, OpenSSLError> {
164        let acceptor = Arc::new(ArcSwap::from_pointee(config_from_pem_file(cert, key)?));
165
166        Ok(OpenSSLConfig { acceptor })
167    }
168
169    /// This helper will establish a TLS server based on strong cipher suites
170    /// from a PEM-formatted certificate chain and key.
171    pub fn from_pem_chain_file(
172        chain: impl AsRef<Path>,
173        key: impl AsRef<Path>,
174    ) -> Result<Self, OpenSSLError> {
175        let acceptor = Arc::new(ArcSwap::from_pointee(config_from_pem_chain_file(
176            chain, key,
177        )?));
178
179        Ok(OpenSSLConfig { acceptor })
180    }
181
182    /// Get inner `Arc<`[`SslAcceptor`]`>`.
183    pub fn get_inner(&self) -> Arc<SslAcceptor> {
184        self.acceptor.load_full()
185    }
186
187    /// Reload acceptor from `Arc<`[`SslAcceptor`]`>`.
188    pub fn reload_from_acceptor(&self, acceptor: Arc<SslAcceptor>) {
189        self.acceptor.store(acceptor);
190    }
191
192    /// Reload acceptor from a DER-encoded certificate and key.
193    pub fn reload_from_der(&self, cert: &[u8], key: &[u8]) -> Result<(), OpenSSLError> {
194        let acceptor = Arc::new(config_from_der(cert, key)?);
195        self.acceptor.store(acceptor);
196
197        Ok(())
198    }
199
200    /// Reload acceptor from a PEM-formatted certificate and key.
201    pub fn reload_from_pem(&self, cert: &[u8], key: &[u8]) -> Result<(), OpenSSLError> {
202        let acceptor = Arc::new(config_from_pem(cert, key)?);
203        self.acceptor.store(acceptor);
204
205        Ok(())
206    }
207
208    /// Reload acceptor from a PEM-formatted certificate and key.
209    pub fn reload_from_pem_file(
210        &self,
211        cert: impl AsRef<Path>,
212        key: impl AsRef<Path>,
213    ) -> Result<(), OpenSSLError> {
214        let acceptor = Arc::new(config_from_pem_file(cert, key)?);
215        self.acceptor.store(acceptor);
216
217        Ok(())
218    }
219
220    /// Reload acceptor from a PEM-formatted certificate chain and key.
221    pub fn reload_from_pem_chain_file(
222        &self,
223        chain: impl AsRef<Path>,
224        key: impl AsRef<Path>,
225    ) -> Result<(), OpenSSLError> {
226        let acceptor = Arc::new(config_from_pem_chain_file(chain, key)?);
227        self.acceptor.store(acceptor);
228
229        Ok(())
230    }
231}
232
233impl TryFrom<SslAcceptorBuilder> for OpenSSLConfig {
234    type Error = OpenSSLError;
235
236    /// Build the [`OpenSSLConfig`] from an [`SslAcceptorBuilder`]. This allows precise
237    /// control over the settings that will be used by OpenSSL in this server.
238    ///
239    /// # Example
240    /// ```
241    /// use hyper_serve::tls_openssl::OpenSSLConfig;
242    /// use openssl::ssl::{SslAcceptor, SslMethod};
243    /// use std::convert::TryFrom;
244    ///
245    /// #[tokio::main]
246    /// async fn main() {
247    ///     let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())
248    ///         .unwrap();
249    ///     // Set configurations like set_certificate_chain_file or
250    ///     // set_private_key_file.
251    ///     // let tls_builder.set_ ... ;
252
253    ///     let _config = OpenSSLConfig::try_from(tls_builder);
254    /// }
255    /// ```
256    fn try_from(mut tls_builder: SslAcceptorBuilder) -> Result<Self, Self::Error> {
257        // Any other checks?
258        tls_builder.check_private_key()?;
259        tls_builder.set_alpn_select_callback(alpn_select);
260
261        let acceptor = Arc::new(ArcSwap::from_pointee(tls_builder.build()));
262
263        Ok(OpenSSLConfig { acceptor })
264    }
265}
266
267impl fmt::Debug for OpenSSLConfig {
268    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269        f.debug_struct("OpenSSLConfig").finish()
270    }
271}
272
273fn alpn_select<'a>(_tls: &mut SslRef, client: &'a [u8]) -> Result<&'a [u8], AlpnError> {
274    ssl::select_next_proto(b"\x02h2\x08http/1.1", client).ok_or(AlpnError::NOACK)
275}
276
277fn config_from_der(cert: &[u8], key: &[u8]) -> Result<SslAcceptor, OpenSSLError> {
278    let cert = X509::from_der(cert)?;
279    let key = PKey::private_key_from_der(key)?;
280
281    let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
282    tls_builder.set_certificate(&cert)?;
283    tls_builder.set_private_key(&key)?;
284    tls_builder.check_private_key()?;
285    tls_builder.set_alpn_select_callback(alpn_select);
286
287    let acceptor = tls_builder.build();
288    Ok(acceptor)
289}
290
291fn config_from_pem(cert: &[u8], key: &[u8]) -> Result<SslAcceptor, OpenSSLError> {
292    let cert = X509::from_pem(cert)?;
293    let key = PKey::private_key_from_pem(key)?;
294
295    let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
296    tls_builder.set_certificate(&cert)?;
297    tls_builder.set_private_key(&key)?;
298    tls_builder.check_private_key()?;
299    tls_builder.set_alpn_select_callback(alpn_select);
300
301    let acceptor = tls_builder.build();
302    Ok(acceptor)
303}
304
305fn config_from_pem_file(
306    cert: impl AsRef<Path>,
307    key: impl AsRef<Path>,
308) -> Result<SslAcceptor, OpenSSLError> {
309    let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
310    tls_builder.set_certificate_file(cert, SslFiletype::PEM)?;
311    tls_builder.set_private_key_file(key, SslFiletype::PEM)?;
312    tls_builder.check_private_key()?;
313    tls_builder.set_alpn_select_callback(alpn_select);
314
315    let acceptor = tls_builder.build();
316    Ok(acceptor)
317}
318
319fn config_from_pem_chain_file(
320    chain: impl AsRef<Path>,
321    key: impl AsRef<Path>,
322) -> Result<SslAcceptor, OpenSSLError> {
323    let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
324    tls_builder.set_certificate_chain_file(chain)?;
325    tls_builder.set_private_key_file(key, SslFiletype::PEM)?;
326    tls_builder.check_private_key()?;
327    tls_builder.set_alpn_select_callback(alpn_select);
328
329    let acceptor = tls_builder.build();
330    Ok(acceptor)
331}
332
#[cfg(test)]
mod tests {
    use crate::{
        handle::Handle,
        tls_openssl::{self, OpenSSLConfig},
    };
    use axum::body::Body;
    use axum::routing::{get, post};
    use axum::Router;
    use bytes::Bytes;
    use http::{response, Request};
    use http_body_util::BodyExt;
    use hyper::client::conn::http1::{handshake, SendRequest};
    use hyper_util::rt::TokioIo;
    use std::{io, net::SocketAddr, time::Duration};
    use tokio::{net::TcpStream, task::JoinHandle, time::timeout};

    use crate::server::tests::slow_body;
    use axum::response::Response;
    use openssl::{
        ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode},
        x509::X509,
    };
    use std::pin::Pin;
    use tokio::sync::oneshot;
    use tokio_openssl::SslStream;

    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    #[tokio::test]
    async fn test_reload() {
        let handle = Handle::new();

        let config = OpenSSLConfig::from_pem_file(
            "examples/self-signed-certs/cert.pem",
            "examples/self-signed-certs/key.pem",
        )
        .unwrap();

        let server_handle = handle.clone();
        let openssl_config = config.clone();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_openssl::bind_openssl(addr, openssl_config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        let cert_a = get_first_cert(addr).await;
        let mut cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/reload/cert.pem",
                "examples/self-signed-certs/reload/key.pem",
            )
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_ne!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let (handle, _server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.shutdown();

        let response_future_result = client.send_request(Request::new(Body::empty())).await;

        assert!(response_future_result.is_err());

        // Connection task should finish soon.
        let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
    }

    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client1, _conn1) = connect(addr).await;
        let (mut client2, _conn2) = connect(addr).await;

        // Clients can send request before graceful shutdown.
        crate::server::tests::do_empty_request(&mut client1)
            .await
            .unwrap();
        crate::server::tests::do_empty_request(&mut client2)
            .await
            .unwrap();

        let start = tokio::time::Instant::now();

        let (hdr1_tx, hdr1_rx) = oneshot::channel::<()>();

        let task1 = async {
            // A slow request made before graceful shutdown is handled.
            // This one is shorter than the timeout, so it should succeed.
            let hdr1 =
                crate::server::tests::send_slow_request(&mut client1, Duration::from_millis(222))
                    .await;
            hdr1_tx.send(()).unwrap();

            let res1 = crate::server::tests::recv_slow_response_body(hdr1.unwrap()).await;
            res1.unwrap();
        };
        let task2 = async {
            // A slow request made before graceful shutdown is handled.
            // This one is much longer than the timeout; it should fail sometime
            // after the graceful shutdown timeout.

            let hdr2 =
                crate::server::tests::send_slow_request(&mut client2, Duration::from_millis(5_555))
                    .await;
            hdr2.unwrap_err();
        };
        let task3 = async {
            // Begin graceful shutdown after we receive response headers for (1).
            hdr1_rx.await.unwrap();

            // Set a timeout on requests to finish before we drop them.
            handle.graceful_shutdown(Some(Duration::from_millis(333)));

            // Server task should finish soon.
            timeout(Duration::from_secs(1), server_task)
                .await
                .unwrap()
                .unwrap()
                .unwrap();

            // At this point, graceful shutdown must have occurred.
            assert!(start.elapsed() >= Duration::from_millis(222 + 333));
            assert!(start.elapsed() <= Duration::from_millis(5_555));
        };

        tokio::join!(task1, task2, task3);
    }

    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new()
                .route("/", get(|| async { "Hello, world!" }))
                .route(
                    "/echo_slowly",
                    post(|body: Bytes| async move {
                        // Stream a response slowly, byte-by-byte, over 100ms
                        Response::new(slow_body(body.len(), Duration::from_millis(100)))
                    }),
                );

            let config = OpenSSLConfig::from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .unwrap();

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_openssl::bind_openssl(addr, config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    async fn get_first_cert(addr: SocketAddr) -> X509 {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = tls_connector(dns_name(), stream).await;

        tls_stream.ssl().peer_certificate().unwrap()
    }

    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = TokioIo::new(tls_connector(dns_name(), stream).await);

        let (send_request, connection) = handshake(tls_stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let (parts, body) = client
            .send_request(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_parts();
        let body = body.collect().await.unwrap().to_bytes();

        (parts, body)
    }

    /// Perform a client TLS handshake over `stream`, sending `hostname` as SNI.
    ///
    /// Certificate verification is disabled because the test server uses a self-signed
    /// certificate.
    async fn tls_connector(hostname: &str, stream: TcpStream) -> SslStream<TcpStream> {
        let mut connector_builder = SslConnector::builder(SslMethod::tls_client()).unwrap();
        connector_builder.set_verify(SslVerifyMode::NONE);
        let connector = connector_builder.build();

        let mut ssl = Ssl::new(connector.context()).unwrap();
        // SNI must be set on the client `Ssl` itself. OpenSSL only invokes client-hello
        // callbacks on the server side, so a callback on the connector would never run.
        ssl.set_hostname(hostname).unwrap();

        let mut tls_stream = SslStream::new(ssl, stream).unwrap();
        SslStream::connect(Pin::new(&mut tls_stream)).await.unwrap();

        tls_stream
    }

    fn dns_name() -> &'static str {
        "localhost"
    }
}