hyper_server/tls_openssl/
mod.rs1use self::future::OpenSSLAcceptorFuture;
30use crate::{
31 accept::{Accept, DefaultAcceptor},
32 server::Server,
33};
34use openssl::ssl::{
35 Error as OpenSSLError, SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod,
36};
37use std::{convert::TryFrom, fmt, net::SocketAddr, path::Path, sync::Arc, time::Duration};
38use tokio::io::{AsyncRead, AsyncWrite};
39use tokio_openssl::SslStream;
40
41pub mod future;
42
43pub fn bind_openssl(addr: SocketAddr, config: OpenSSLConfig) -> Server<OpenSSLAcceptor> {
56 let acceptor = OpenSSLAcceptor::new(config);
57 Server::bind(addr).acceptor(acceptor)
58}
59
60#[derive(Clone)]
69pub struct OpenSSLAcceptor<A = DefaultAcceptor> {
70 inner: A,
71 config: OpenSSLConfig,
72 handshake_timeout: Duration,
73}
74
75impl OpenSSLAcceptor {
76 pub fn new(config: OpenSSLConfig) -> Self {
82 let inner = DefaultAcceptor::new();
83
84 #[cfg(not(test))]
86 let handshake_timeout = Duration::from_secs(10);
87
88 #[cfg(test)]
90 let handshake_timeout = Duration::from_secs(1);
91
92 Self {
93 inner,
94 config,
95 handshake_timeout,
96 }
97 }
98
99 pub fn handshake_timeout(mut self, val: Duration) -> Self {
109 self.handshake_timeout = val;
110 self
111 }
112}
113
114impl<A, I, S> Accept<I, S> for OpenSSLAcceptor<A>
115where
116 A: Accept<I, S>,
117 A::Stream: AsyncRead + AsyncWrite + Unpin,
118{
119 type Stream = SslStream<A::Stream>;
120 type Service = A::Service;
121 type Future = OpenSSLAcceptorFuture<A::Future, A::Stream, A::Service>;
122
123 fn accept(&self, stream: I, service: S) -> Self::Future {
125 let inner_future = self.inner.accept(stream, service);
126 let config = self.config.clone();
127
128 OpenSSLAcceptorFuture::new(inner_future, config, self.handshake_timeout)
129 }
130}
131
132impl<A> fmt::Debug for OpenSSLAcceptor<A> {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("OpenSSLAcceptor").finish()
135 }
136}
137
138#[derive(Clone)]
142pub struct OpenSSLConfig {
143 acceptor: Arc<SslAcceptor>,
144}
145
146impl OpenSSLConfig {
147 pub fn from_pem_file<A: AsRef<Path>, B: AsRef<Path>>(
158 cert: A,
159 key: B,
160 ) -> Result<Self, OpenSSLError> {
161 let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
162
163 tls_builder.set_certificate_file(cert, SslFiletype::PEM)?;
164 tls_builder.set_private_key_file(key, SslFiletype::PEM)?;
165 tls_builder.check_private_key()?;
166
167 let acceptor = Arc::new(tls_builder.build());
168
169 Ok(OpenSSLConfig { acceptor })
170 }
171
172 pub fn from_pem_chain_file<A: AsRef<Path>, B: AsRef<Path>>(
183 chain: A,
184 key: B,
185 ) -> Result<Self, OpenSSLError> {
186 let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
187
188 tls_builder.set_certificate_chain_file(chain)?;
189 tls_builder.set_private_key_file(key, SslFiletype::PEM)?;
190 tls_builder.check_private_key()?;
191
192 let acceptor = Arc::new(tls_builder.build());
193
194 Ok(OpenSSLConfig { acceptor })
195 }
196}
197
198impl TryFrom<SslAcceptorBuilder> for OpenSSLConfig {
199 type Error = OpenSSLError;
200
201 fn try_from(tls_builder: SslAcceptorBuilder) -> Result<Self, Self::Error> {
222 tls_builder.check_private_key()?;
223 let acceptor = Arc::new(tls_builder.build());
224 Ok(OpenSSLConfig { acceptor })
225 }
226}
227
228impl fmt::Debug for OpenSSLConfig {
229 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230 f.debug_struct("OpenSSLConfig").finish()
231 }
232}
233
#[cfg(test)]
pub(crate) mod tests {
    use crate::{
        handle::Handle,
        tls_openssl::{self, OpenSSLConfig},
    };
    use axum::{routing::get, Router};
    use bytes::Bytes;
    use http::{response, Request};
    use hyper::{
        client::conn::{handshake, SendRequest},
        Body,
    };
    use std::{io, net::SocketAddr, time::Duration};
    use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
    use tower::{Service, ServiceExt};

    use openssl::ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode};
    use std::pin::Pin;
    use tokio_openssl::SslStream;

    /// A single request over TLS reaches the router and returns its body.
    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    /// After an immediate shutdown, a request on the open connection fails and the
    /// client connection task completes within one second.
    #[tokio::test]
    async fn test_shutdown() {
        let (handle, _server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.shutdown();

        let response_future_result = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await;

        assert!(response_future_result.is_err());

        // Panics if the connection task has not finished within the timeout.
        let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
    }

    /// Graceful shutdown still serves a request on an already-open connection; once the
    /// client connection is aborted, the server task exits with `Ok` within one second.
    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.graceful_shutdown(None);

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");

        conn.abort();

        let server_result = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();

        assert!(server_result.is_ok());
    }

    /// With a 250 ms grace period the request still succeeds, and the server exits with
    /// `Ok` within one second even though the client connection (`_conn`) stays open.
    #[tokio::test]
    async fn test_graceful_shutdown_timed() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        handle.graceful_shutdown(Some(Duration::from_millis(250)));

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");

        let server_result = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();

        assert!(server_result.is_ok());
    }

    /// Spawns a TLS server on an ephemeral 127.0.0.1 port using the example self-signed
    /// certificate and key. Returns the handle, the server task, and the listening address.
    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let config = OpenSSLConfig::from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .unwrap();

            // Port 0 lets the OS choose a free port; the real address comes from `listening`.
            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_openssl::bind_openssl(addr, config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Opens a TCP + TLS connection to `addr` and performs the hyper client handshake.
    /// The connection driver runs in a spawned task whose handle is returned.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = tls_connector(dns_name(), stream).await;

        let (send_request, connection) = handshake(tls_stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    /// Sends an empty-bodied default request and returns the response parts and full body.
    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let (parts, body) = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_parts();
        let body = hyper::body::to_bytes(body).await.unwrap();

        (parts, body)
    }

    /// Performs a client-side TLS handshake over `stream`, sending `hostname` as SNI.
    ///
    /// Certificate verification is disabled because the test server uses the
    /// self-signed certificates under `examples/self-signed-certs`.
    pub(crate) async fn tls_connector(hostname: &str, stream: TcpStream) -> SslStream<TcpStream> {
        let mut tls_parms = SslConnector::builder(SslMethod::tls_client()).unwrap();
        tls_parms.set_verify(SslVerifyMode::NONE);
        let tls_parms = tls_parms.build();

        let mut ssl = Ssl::new(tls_parms.context()).unwrap();
        // Set SNI directly on the client `Ssl`. A client-hello callback cannot do this:
        // OpenSSL only invokes that callback on the server while it parses a ClientHello.
        ssl.set_hostname(hostname).unwrap();
        let mut tls_stream = SslStream::new(ssl, stream).unwrap();

        SslStream::connect(Pin::new(&mut tls_stream)).await.unwrap();

        tls_stream
    }

    /// Host name the test client presents to the server.
    pub(crate) fn dns_name() -> &'static str {
        "localhost"
    }
}