hyper_server/tls_openssl/
mod.rs

//! TLS implementation using [`openssl`].
//!
//! # Example
//!
//! ```rust,no_run
//! use axum::{routing::get, Router};
//! use hyper_server::tls_openssl::OpenSSLConfig;
//! use std::net::SocketAddr;
//!
//! #[tokio::main]
//! async fn main() {
//!     let app = Router::new().route("/", get(|| async { "Hello, world!" }));
//!
//!     let config = OpenSSLConfig::from_pem_file(
//!         "examples/self-signed-certs/cert.pem",
//!         "examples/self-signed-certs/key.pem",
//!     )
//!     .unwrap();
//!
//!     let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
//!     println!("listening on {}", addr);
//!     hyper_server::bind_openssl(addr, config)
//!         .serve(app.into_make_service())
//!         .await
//!         .unwrap();
//! }
//! ```
28
29use self::future::OpenSSLAcceptorFuture;
30use crate::{
31    accept::{Accept, DefaultAcceptor},
32    server::Server,
33};
34use openssl::ssl::{
35    Error as OpenSSLError, SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod,
36};
37use std::{convert::TryFrom, fmt, net::SocketAddr, path::Path, sync::Arc, time::Duration};
38use tokio::io::{AsyncRead, AsyncWrite};
39use tokio_openssl::SslStream;
40
41pub mod future;
42
43/// Binds a TLS server using OpenSSL to the specified address with the given configuration.
44///
45/// The server is configured to accept TLS encrypted connections.
46///
47/// # Arguments
48///
49/// * `addr`: The address to which the server will bind.
50/// * `config`: The TLS configuration for the server.
51///
52/// # Returns
53///
54/// A configured `Server` instance ready to be run.
55pub fn bind_openssl(addr: SocketAddr, config: OpenSSLConfig) -> Server<OpenSSLAcceptor> {
56    let acceptor = OpenSSLAcceptor::new(config);
57    Server::bind(addr).acceptor(acceptor)
58}
59
60/// Represents a TLS acceptor that uses OpenSSL for cryptographic operations.
61///
62/// This structure is used for handling TLS encrypted connections.
63///
64/// The acceptor is backed by OpenSSL, and is used to upgrade incoming non-secure connections
65/// to secure TLS connections.
66///
67/// The default TLS handshake timeout is set to 10 seconds.
68#[derive(Clone)]
69pub struct OpenSSLAcceptor<A = DefaultAcceptor> {
70    inner: A,
71    config: OpenSSLConfig,
72    handshake_timeout: Duration,
73}
74
75impl OpenSSLAcceptor {
76    /// Constructs a new instance of the OpenSSL acceptor.
77    ///
78    /// # Arguments
79    ///
80    /// * `config`: Configuration options for the OpenSSL server.
81    pub fn new(config: OpenSSLConfig) -> Self {
82        let inner = DefaultAcceptor::new();
83
84        // Default handshake timeout is 10 seconds.
85        #[cfg(not(test))]
86        let handshake_timeout = Duration::from_secs(10);
87
88        // For tests, use a shorter timeout to avoid unnecessary delays.
89        #[cfg(test)]
90        let handshake_timeout = Duration::from_secs(1);
91
92        Self {
93            inner,
94            config,
95            handshake_timeout,
96        }
97    }
98
99    /// Overrides the default TLS handshake timeout.
100    ///
101    /// # Arguments
102    ///
103    /// * `val`: The duration to set as the new handshake timeout.
104    ///
105    /// # Returns
106    ///
107    /// A modified version of the current acceptor with the new timeout value.
108    pub fn handshake_timeout(mut self, val: Duration) -> Self {
109        self.handshake_timeout = val;
110        self
111    }
112}
113
114impl<A, I, S> Accept<I, S> for OpenSSLAcceptor<A>
115where
116    A: Accept<I, S>,
117    A::Stream: AsyncRead + AsyncWrite + Unpin,
118{
119    type Stream = SslStream<A::Stream>;
120    type Service = A::Service;
121    type Future = OpenSSLAcceptorFuture<A::Future, A::Stream, A::Service>;
122
123    /// Handles the incoming stream, initiates a TLS handshake, and upgrades it to a secure connection.
124    fn accept(&self, stream: I, service: S) -> Self::Future {
125        let inner_future = self.inner.accept(stream, service);
126        let config = self.config.clone();
127
128        OpenSSLAcceptorFuture::new(inner_future, config, self.handshake_timeout)
129    }
130}
131
132impl<A> fmt::Debug for OpenSSLAcceptor<A> {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        f.debug_struct("OpenSSLAcceptor").finish()
135    }
136}
137
138/// Represents configuration options for an OpenSSL-based server.
139///
140/// This configuration is used when constructing a new `OpenSSLAcceptor`.
141#[derive(Clone)]
142pub struct OpenSSLConfig {
143    acceptor: Arc<SslAcceptor>,
144}
145
146impl OpenSSLConfig {
147    /// Creates a new configuration using a PEM formatted certificate and key.
148    ///
149    /// # Arguments
150    ///
151    /// * `cert`: Path to the PEM-formatted certificate file.
152    /// * `key`: Path to the PEM-formatted private key file.
153    ///
154    /// # Returns
155    ///
156    /// A `Result` that contains an `OpenSSLConfig` or an `OpenSSLError`.
157    pub fn from_pem_file<A: AsRef<Path>, B: AsRef<Path>>(
158        cert: A,
159        key: B,
160    ) -> Result<Self, OpenSSLError> {
161        let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
162
163        tls_builder.set_certificate_file(cert, SslFiletype::PEM)?;
164        tls_builder.set_private_key_file(key, SslFiletype::PEM)?;
165        tls_builder.check_private_key()?;
166
167        let acceptor = Arc::new(tls_builder.build());
168
169        Ok(OpenSSLConfig { acceptor })
170    }
171
172    /// Creates a new configuration using a PEM formatted certificate chain and key.
173    ///
174    /// # Arguments
175    ///
176    /// * `chain`: Path to the PEM-formatted certificate chain file.
177    /// * `key`: Path to the PEM-formatted private key file.
178    ///
179    /// # Returns
180    ///
181    /// A `Result` that contains an `OpenSSLConfig` or an `OpenSSLError`.
182    pub fn from_pem_chain_file<A: AsRef<Path>, B: AsRef<Path>>(
183        chain: A,
184        key: B,
185    ) -> Result<Self, OpenSSLError> {
186        let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
187
188        tls_builder.set_certificate_chain_file(chain)?;
189        tls_builder.set_private_key_file(key, SslFiletype::PEM)?;
190        tls_builder.check_private_key()?;
191
192        let acceptor = Arc::new(tls_builder.build());
193
194        Ok(OpenSSLConfig { acceptor })
195    }
196}
197
198impl TryFrom<SslAcceptorBuilder> for OpenSSLConfig {
199    type Error = OpenSSLError;
200
201    /// Constructs [`OpenSSLConfig`] from an [`SslAcceptorBuilder`]. This allows precise
202    /// control over the settings that will be used by OpenSSL in this server.
203    ///
204    /// # Example
205    /// ```
206    /// use hyper_server::tls_openssl::OpenSSLConfig;
207    /// use openssl::ssl::{SslAcceptor, SslMethod};
208    /// use std::convert::TryFrom;
209    ///
210    /// #[tokio::main]
211    /// async fn main() {
212    ///     let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())
213    ///         .unwrap();
214    ///     // Set configurations like set_certificate_chain_file or
215    ///     // set_private_key_file.
216    ///     // let tls_builder.set_ ... ;
217
218    ///     let _config = OpenSSLConfig::try_from(tls_builder);
219    /// }
220    /// ```
221    fn try_from(tls_builder: SslAcceptorBuilder) -> Result<Self, Self::Error> {
222        tls_builder.check_private_key()?;
223        let acceptor = Arc::new(tls_builder.build());
224        Ok(OpenSSLConfig { acceptor })
225    }
226}
227
228impl fmt::Debug for OpenSSLConfig {
229    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230        f.debug_struct("OpenSSLConfig").finish()
231    }
232}
233
#[cfg(test)]
pub(crate) mod tests {
    use crate::{
        handle::Handle,
        tls_openssl::{self, OpenSSLConfig},
    };
    use axum::{routing::get, Router};
    use bytes::Bytes;
    use http::{response, Request};
    use hyper::{
        client::conn::{handshake, SendRequest},
        Body,
    };
    use std::{io, net::SocketAddr, time::Duration};
    use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
    use tower::{Service, ServiceExt};

    use openssl::ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode};
    use std::pin::Pin;
    use tokio_openssl::SslStream;

    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    #[tokio::test]
    async fn test_shutdown() {
        let (handle, _server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.shutdown();

        let response_future_result = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await;

        assert!(response_future_result.is_err());

        // Connection task should finish soon.
        let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
    }

    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.graceful_shutdown(None);

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");

        // Disconnect client.
        conn.abort();

        // Server task should finish soon.
        let server_result = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();

        assert!(server_result.is_ok());
    }

    #[tokio::test]
    async fn test_graceful_shutdown_timed() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        handle.graceful_shutdown(Some(Duration::from_millis(250)));

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");

        // The client is intentionally left connected. The 250ms graceful-shutdown
        // deadline, not a client disconnect, is what must let the server stop.

        // Server task should finish soon.
        let server_result = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();

        assert!(server_result.is_ok());
    }

    /// Spawns the "Hello, world!" app on an ephemeral port (127.0.0.1:0).
    ///
    /// The server uses the example self-signed certificate. This function waits until
    /// the server reports its listening address.
    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let config = OpenSSLConfig::from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .unwrap();

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_openssl::bind_openssl(addr, config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Opens a TLS connection to `addr` and performs a hyper client handshake.
    ///
    /// The connection driver runs on the returned spawned task.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = tls_connector(dns_name(), stream).await;

        let (send_request, connection) = handshake(tls_stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    /// Sends an empty request over `client` and collects the full response body.
    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let (parts, body) = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_parts();
        let body = hyper::body::to_bytes(body).await.unwrap();

        (parts, body)
    }

    /// Performs a client-side TLS handshake over `stream`, sending `hostname` as SNI.
    ///
    /// Certificate verification is disabled because the test server uses a
    /// self-signed certificate.
    ///
    /// Used in `proxy-protocol` feature tests.
    pub(crate) async fn tls_connector(hostname: &str, stream: TcpStream) -> SslStream<TcpStream> {
        let mut connector = SslConnector::builder(SslMethod::tls_client()).unwrap();
        // The test certificates are self-signed, so skip chain verification.
        connector.set_verify(SslVerifyMode::NONE);
        let connector = connector.build();

        let mut ssl = Ssl::new(connector.context()).unwrap();
        // Set SNI directly on the connection. A client-hello callback cannot do this:
        // OpenSSL only invokes that callback while a *server* processes a ClientHello,
        // so on a client it never runs and the hostname would never be sent.
        ssl.set_hostname(hostname).unwrap();

        let mut tls_stream = SslStream::new(ssl, stream).unwrap();
        SslStream::connect(Pin::new(&mut tls_stream)).await.unwrap();

        tls_stream
    }

    /// Hostname used for SNI by the test client.
    ///
    /// Used in `proxy-protocol` feature tests.
    pub(crate) fn dns_name() -> &'static str {
        "localhost"
    }
}