1use self::future::OpenSSLAcceptorFuture;
30use crate::{
31 accept::{Accept, DefaultAcceptor},
32 server::Server,
33};
34use arc_swap::ArcSwap;
35use openssl::{
36 pkey::PKey,
37 ssl::{
38 self, AlpnError, Error as OpenSSLError, SslAcceptor, SslAcceptorBuilder, SslFiletype,
39 SslMethod, SslRef,
40 },
41 x509::X509,
42};
43use std::{convert::TryFrom, fmt, net::SocketAddr, path::Path, sync::Arc, time::Duration};
44use tokio::io::{AsyncRead, AsyncWrite};
45use tokio_openssl::SslStream;
46
47pub mod future;
48
49pub fn bind_openssl(addr: SocketAddr, config: OpenSSLConfig) -> Server<OpenSSLAcceptor> {
52 let acceptor = OpenSSLAcceptor::new(config);
53
54 Server::bind(addr).acceptor(acceptor)
55}
56
57#[derive(Clone)]
60pub struct OpenSSLAcceptor<A = DefaultAcceptor> {
61 inner: A,
62 config: OpenSSLConfig,
63 handshake_timeout: Duration,
64}
65
66impl OpenSSLAcceptor {
67 pub fn new(config: OpenSSLConfig) -> Self {
71 let inner = DefaultAcceptor::new();
72
73 #[cfg(not(test))]
74 let handshake_timeout = Duration::from_secs(10);
75
76 #[cfg(test)]
78 let handshake_timeout = Duration::from_secs(1);
79
80 Self {
81 inner,
82 config,
83 handshake_timeout,
84 }
85 }
86
87 pub fn handshake_timeout(mut self, val: Duration) -> Self {
89 self.handshake_timeout = val;
90 self
91 }
92}
93
94impl<A> OpenSSLAcceptor<A> {
95 pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> OpenSSLAcceptor<Acceptor> {
97 OpenSSLAcceptor {
98 inner: acceptor,
99 config: self.config,
100 handshake_timeout: self.handshake_timeout,
101 }
102 }
103}
104
105impl<A, I, S> Accept<I, S> for OpenSSLAcceptor<A>
106where
107 A: Accept<I, S>,
108 A::Stream: AsyncRead + AsyncWrite + Unpin,
109{
110 type Stream = SslStream<A::Stream>;
111 type Service = A::Service;
112 type Future = OpenSSLAcceptorFuture<A::Future, A::Stream, A::Service>;
113
114 fn accept(&self, stream: I, service: S) -> Self::Future {
115 let inner_future = self.inner.accept(stream, service);
116 let config = self.config.clone();
117
118 OpenSSLAcceptorFuture::new(inner_future, config, self.handshake_timeout)
119 }
120}
121
122impl<A> fmt::Debug for OpenSSLAcceptor<A> {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("OpenSSLAcceptor").finish()
125 }
126}
127
128#[derive(Clone)]
130pub struct OpenSSLConfig {
131 acceptor: Arc<ArcSwap<SslAcceptor>>,
132}
133
134impl OpenSSLConfig {
135 pub fn from_acceptor(acceptor: Arc<SslAcceptor>) -> Self {
137 let acceptor = Arc::new(ArcSwap::new(acceptor));
138
139 OpenSSLConfig { acceptor }
140 }
141
142 pub fn from_der(cert: &[u8], key: &[u8]) -> Result<Self, OpenSSLError> {
145 let acceptor = Arc::new(ArcSwap::from_pointee(config_from_der(cert, key)?));
146
147 Ok(OpenSSLConfig { acceptor })
148 }
149
150 pub fn from_pem(cert: &[u8], key: &[u8]) -> Result<Self, OpenSSLError> {
153 let acceptor = Arc::new(ArcSwap::from_pointee(config_from_pem(cert, key)?));
154
155 Ok(OpenSSLConfig { acceptor })
156 }
157
158 pub fn from_pem_file(
161 cert: impl AsRef<Path>,
162 key: impl AsRef<Path>,
163 ) -> Result<Self, OpenSSLError> {
164 let acceptor = Arc::new(ArcSwap::from_pointee(config_from_pem_file(cert, key)?));
165
166 Ok(OpenSSLConfig { acceptor })
167 }
168
169 pub fn from_pem_chain_file(
172 chain: impl AsRef<Path>,
173 key: impl AsRef<Path>,
174 ) -> Result<Self, OpenSSLError> {
175 let acceptor = Arc::new(ArcSwap::from_pointee(config_from_pem_chain_file(
176 chain, key,
177 )?));
178
179 Ok(OpenSSLConfig { acceptor })
180 }
181
182 pub fn get_inner(&self) -> Arc<SslAcceptor> {
184 self.acceptor.load_full()
185 }
186
187 pub fn reload_from_acceptor(&self, acceptor: Arc<SslAcceptor>) {
189 self.acceptor.store(acceptor);
190 }
191
192 pub fn reload_from_der(&self, cert: &[u8], key: &[u8]) -> Result<(), OpenSSLError> {
194 let acceptor = Arc::new(config_from_der(cert, key)?);
195 self.acceptor.store(acceptor);
196
197 Ok(())
198 }
199
200 pub fn reload_from_pem(&self, cert: &[u8], key: &[u8]) -> Result<(), OpenSSLError> {
202 let acceptor = Arc::new(config_from_pem(cert, key)?);
203 self.acceptor.store(acceptor);
204
205 Ok(())
206 }
207
208 pub fn reload_from_pem_file(
210 &self,
211 cert: impl AsRef<Path>,
212 key: impl AsRef<Path>,
213 ) -> Result<(), OpenSSLError> {
214 let acceptor = Arc::new(config_from_pem_file(cert, key)?);
215 self.acceptor.store(acceptor);
216
217 Ok(())
218 }
219
220 pub fn reload_from_pem_chain_file(
222 &self,
223 chain: impl AsRef<Path>,
224 key: impl AsRef<Path>,
225 ) -> Result<(), OpenSSLError> {
226 let acceptor = Arc::new(config_from_pem_chain_file(chain, key)?);
227 self.acceptor.store(acceptor);
228
229 Ok(())
230 }
231}
232
233impl TryFrom<SslAcceptorBuilder> for OpenSSLConfig {
234 type Error = OpenSSLError;
235
236 fn try_from(mut tls_builder: SslAcceptorBuilder) -> Result<Self, Self::Error> {
257 tls_builder.check_private_key()?;
259 tls_builder.set_alpn_select_callback(alpn_select);
260
261 let acceptor = Arc::new(ArcSwap::from_pointee(tls_builder.build()));
262
263 Ok(OpenSSLConfig { acceptor })
264 }
265}
266
267impl fmt::Debug for OpenSSLConfig {
268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269 f.debug_struct("OpenSSLConfig").finish()
270 }
271}
272
273fn alpn_select<'a>(_tls: &mut SslRef, client: &'a [u8]) -> Result<&'a [u8], AlpnError> {
274 ssl::select_next_proto(b"\x02h2\x08http/1.1", client).ok_or(AlpnError::NOACK)
275}
276
277fn config_from_der(cert: &[u8], key: &[u8]) -> Result<SslAcceptor, OpenSSLError> {
278 let cert = X509::from_der(cert)?;
279 let key = PKey::private_key_from_der(key)?;
280
281 let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
282 tls_builder.set_certificate(&cert)?;
283 tls_builder.set_private_key(&key)?;
284 tls_builder.check_private_key()?;
285 tls_builder.set_alpn_select_callback(alpn_select);
286
287 let acceptor = tls_builder.build();
288 Ok(acceptor)
289}
290
291fn config_from_pem(cert: &[u8], key: &[u8]) -> Result<SslAcceptor, OpenSSLError> {
292 let cert = X509::from_pem(cert)?;
293 let key = PKey::private_key_from_pem(key)?;
294
295 let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
296 tls_builder.set_certificate(&cert)?;
297 tls_builder.set_private_key(&key)?;
298 tls_builder.check_private_key()?;
299 tls_builder.set_alpn_select_callback(alpn_select);
300
301 let acceptor = tls_builder.build();
302 Ok(acceptor)
303}
304
305fn config_from_pem_file(
306 cert: impl AsRef<Path>,
307 key: impl AsRef<Path>,
308) -> Result<SslAcceptor, OpenSSLError> {
309 let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
310 tls_builder.set_certificate_file(cert, SslFiletype::PEM)?;
311 tls_builder.set_private_key_file(key, SslFiletype::PEM)?;
312 tls_builder.check_private_key()?;
313 tls_builder.set_alpn_select_callback(alpn_select);
314
315 let acceptor = tls_builder.build();
316 Ok(acceptor)
317}
318
319fn config_from_pem_chain_file(
320 chain: impl AsRef<Path>,
321 key: impl AsRef<Path>,
322) -> Result<SslAcceptor, OpenSSLError> {
323 let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?;
324 tls_builder.set_certificate_chain_file(chain)?;
325 tls_builder.set_private_key_file(key, SslFiletype::PEM)?;
326 tls_builder.check_private_key()?;
327 tls_builder.set_alpn_select_callback(alpn_select);
328
329 let acceptor = tls_builder.build();
330 Ok(acceptor)
331}
332
333#[cfg(test)]
334mod tests {
335 use crate::{
336 handle::Handle,
337 tls_openssl::{self, OpenSSLConfig},
338 };
339 use axum::body::Body;
340 use axum::routing::{get, post};
341 use axum::Router;
342 use bytes::Bytes;
343 use http::{response, Request};
344 use http_body_util::BodyExt;
345 use hyper::client::conn::http1::{handshake, SendRequest};
346 use hyper_util::rt::TokioIo;
347 use std::{io, net::SocketAddr, time::Duration};
348 use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
349
350 use crate::server::tests::slow_body;
351 use axum::response::Response;
352 use openssl::{
353 ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode},
354 x509::X509,
355 };
356 use std::pin::Pin;
357 use tokio::sync::oneshot;
358 use tokio_openssl::SslStream;
359
360 #[tokio::test]
361 async fn start_and_request() {
362 let (_handle, _server_task, addr) = start_server().await;
363
364 let (mut client, _conn) = connect(addr).await;
365
366 let (_parts, body) = send_empty_request(&mut client).await;
367
368 assert_eq!(body.as_ref(), b"Hello, world!");
369 }
370
371 #[tokio::test]
372 async fn test_reload() {
373 let handle = Handle::new();
374
375 let config = OpenSSLConfig::from_pem_file(
376 "examples/self-signed-certs/cert.pem",
377 "examples/self-signed-certs/key.pem",
378 )
379 .unwrap();
380
381 let server_handle = handle.clone();
382 let openssl_config = config.clone();
383 tokio::spawn(async move {
384 let app = Router::new().route("/", get(|| async { "Hello, world!" }));
385
386 let addr = SocketAddr::from(([127, 0, 0, 1], 0));
387
388 tls_openssl::bind_openssl(addr, openssl_config)
389 .handle(server_handle)
390 .serve(app.into_make_service())
391 .await
392 });
393
394 let addr = handle.listening().await.unwrap();
395
396 let cert_a = get_first_cert(addr).await;
397 let mut cert_b = get_first_cert(addr).await;
398
399 assert_eq!(cert_a, cert_b);
400
401 config
402 .reload_from_pem_file(
403 "examples/self-signed-certs/reload/cert.pem",
404 "examples/self-signed-certs/reload/key.pem",
405 )
406 .unwrap();
407
408 cert_b = get_first_cert(addr).await;
409
410 assert_ne!(cert_a, cert_b);
411
412 config
413 .reload_from_pem_file(
414 "examples/self-signed-certs/cert.pem",
415 "examples/self-signed-certs/key.pem",
416 )
417 .unwrap();
418
419 cert_b = get_first_cert(addr).await;
420
421 assert_eq!(cert_a, cert_b);
422 }
423
424 #[tokio::test]
425 async fn test_shutdown() {
426 let (handle, _server_task, addr) = start_server().await;
427
428 let (mut client, conn) = connect(addr).await;
429
430 handle.shutdown();
431
432 let response_future_result = client.send_request(Request::new(Body::empty())).await;
433
434 assert!(response_future_result.is_err());
435
436 let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
438 }
439
440 #[tokio::test]
441 async fn test_graceful_shutdown_timeout() {
442 let (handle, server_task, addr) = start_server().await;
443
444 let (mut client1, _conn1) = connect(addr).await;
445 let (mut client2, _conn2) = connect(addr).await;
446
447 crate::server::tests::do_empty_request(&mut client1)
449 .await
450 .unwrap();
451 crate::server::tests::do_empty_request(&mut client2)
452 .await
453 .unwrap();
454
455 let start = tokio::time::Instant::now();
456
457 let (hdr1_tx, hdr1_rx) = oneshot::channel::<()>();
458
459 let task1 = async {
460 let hdr1 =
463 crate::server::tests::send_slow_request(&mut client1, Duration::from_millis(222))
464 .await;
465 hdr1_tx.send(()).unwrap();
466
467 let res1 = crate::server::tests::recv_slow_response_body(hdr1.unwrap()).await;
468 res1.unwrap();
469 };
470 let task2 = async {
471 let hdr2 =
476 crate::server::tests::send_slow_request(&mut client2, Duration::from_millis(5_555))
477 .await;
478 hdr2.unwrap_err();
479 };
480 let task3 = async {
481 hdr1_rx.await.unwrap();
483
484 handle.graceful_shutdown(Some(Duration::from_millis(333)));
486
487 timeout(Duration::from_secs(1), server_task)
489 .await
490 .unwrap()
491 .unwrap()
492 .unwrap();
493
494 assert!(start.elapsed() >= Duration::from_millis(222 + 333));
496 assert!(start.elapsed() <= Duration::from_millis(5_555));
497 };
498
499 tokio::join!(task1, task2, task3);
500 }
501
502 async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
503 let handle = Handle::new();
504
505 let server_handle = handle.clone();
506 let server_task = tokio::spawn(async move {
507 let app = Router::new()
508 .route("/", get(|| async { "Hello, world!" }))
509 .route(
510 "/echo_slowly",
511 post(|body: Bytes| async move {
512 Response::new(slow_body(body.len(), Duration::from_millis(100)))
514 }),
515 );
516
517 let config = OpenSSLConfig::from_pem_file(
518 "examples/self-signed-certs/cert.pem",
519 "examples/self-signed-certs/key.pem",
520 )
521 .unwrap();
522
523 let addr = SocketAddr::from(([127, 0, 0, 1], 0));
524
525 tls_openssl::bind_openssl(addr, config)
526 .handle(server_handle)
527 .serve(app.into_make_service())
528 .await
529 });
530
531 let addr = handle.listening().await.unwrap();
532
533 (handle, server_task, addr)
534 }
535
536 async fn get_first_cert(addr: SocketAddr) -> X509 {
537 let stream = TcpStream::connect(addr).await.unwrap();
538 let tls_stream = tls_connector(dns_name(), stream).await;
539
540 tls_stream.ssl().peer_certificate().unwrap()
541 }
542
543 async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
544 let stream = TcpStream::connect(addr).await.unwrap();
545 let tls_stream = TokioIo::new(tls_connector(dns_name(), stream).await);
546
547 let (send_request, connection) = handshake(tls_stream).await.unwrap();
548
549 let task = tokio::spawn(async move {
550 let _ = connection.await;
551 });
552
553 (send_request, task)
554 }
555
556 async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
557 let (parts, body) = client
558 .send_request(Request::new(Body::empty()))
559 .await
560 .unwrap()
561 .into_parts();
562 let body = body.collect().await.unwrap().to_bytes();
563
564 (parts, body)
565 }
566
567 async fn tls_connector(hostname: &str, stream: TcpStream) -> SslStream<TcpStream> {
568 let mut tls_parms = SslConnector::builder(SslMethod::tls_client()).unwrap();
569 tls_parms.set_verify(SslVerifyMode::NONE);
570 let hostname_owned = hostname.to_string();
571 tls_parms.set_client_hello_callback(move |ssl_ref, _ssl_alert| {
572 ssl_ref
573 .set_hostname(hostname_owned.as_str())
574 .map(|()| openssl::ssl::ClientHelloResponse::SUCCESS)
575 });
576 let tls_parms = tls_parms.build();
577
578 let ssl = Ssl::new(tls_parms.context()).unwrap();
579 let mut tls_stream = SslStream::new(ssl, stream).unwrap();
580
581 SslStream::connect(Pin::new(&mut tls_stream)).await.unwrap();
582
583 tls_stream
584 }
585
586 fn dns_name() -> &'static str {
587 "localhost"
588 }
589}