1use self::future::RustlsAcceptorFuture;
31use crate::{
32 accept::{Accept, DefaultAcceptor},
33 server::{io_other, Server},
34};
35use arc_swap::ArcSwap;
36use rustls::ServerConfig;
37use rustls_pemfile::Item;
38use rustls_pki_types::{CertificateDer, PrivateKeyDer};
39use std::time::Duration;
40use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc};
41use tokio::{
42 io::{AsyncRead, AsyncWrite},
43 task::spawn_blocking,
44};
45use tokio_rustls::server::TlsStream;
46
47pub(crate) mod export {
48 use super::*;
49
50 #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
52 pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
53 super::bind_rustls(addr, config)
54 }
55
56 #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
58 pub fn from_tcp_rustls(
59 listener: std::net::TcpListener,
60 config: RustlsConfig,
61 ) -> Server<RustlsAcceptor> {
62 let acceptor = RustlsAcceptor::new(config);
63
64 Server::from_tcp(listener).acceptor(acceptor)
65 }
66}
67
68pub mod future;
69
70pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
72 let acceptor = RustlsAcceptor::new(config);
73
74 Server::bind(addr).acceptor(acceptor)
75}
76
77pub fn from_tcp_rustls(
79 listener: std::net::TcpListener,
80 config: RustlsConfig,
81) -> Server<RustlsAcceptor> {
82 let acceptor = RustlsAcceptor::new(config);
83
84 Server::from_tcp(listener).acceptor(acceptor)
85}
86
87#[derive(Clone)]
89pub struct RustlsAcceptor<A = DefaultAcceptor> {
90 inner: A,
91 config: RustlsConfig,
92 handshake_timeout: Duration,
93}
94
95impl RustlsAcceptor {
96 pub fn new(config: RustlsConfig) -> Self {
98 let inner = DefaultAcceptor::new();
99
100 #[cfg(not(test))]
101 let handshake_timeout = Duration::from_secs(10);
102
103 #[cfg(test)]
105 let handshake_timeout = Duration::from_secs(1);
106
107 Self {
108 inner,
109 config,
110 handshake_timeout,
111 }
112 }
113
114 pub fn handshake_timeout(mut self, val: Duration) -> Self {
116 self.handshake_timeout = val;
117 self
118 }
119}
120
121impl<A> RustlsAcceptor<A> {
122 pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> RustlsAcceptor<Acceptor> {
124 RustlsAcceptor {
125 inner: acceptor,
126 config: self.config,
127 handshake_timeout: self.handshake_timeout,
128 }
129 }
130}
131
132impl<A, I, S> Accept<I, S> for RustlsAcceptor<A>
133where
134 A: Accept<I, S>,
135 A::Stream: AsyncRead + AsyncWrite + Unpin,
136{
137 type Stream = TlsStream<A::Stream>;
138 type Service = A::Service;
139 type Future = RustlsAcceptorFuture<A::Future, A::Stream, A::Service>;
140
141 fn accept(&self, stream: I, service: S) -> Self::Future {
142 let inner_future = self.inner.accept(stream, service);
143 let config = self.config.clone();
144
145 RustlsAcceptorFuture::new(inner_future, config, self.handshake_timeout)
146 }
147}
148
149impl<A> fmt::Debug for RustlsAcceptor<A> {
150 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151 f.debug_struct("RustlsAcceptor").finish()
152 }
153}
154
155#[derive(Clone)]
157pub struct RustlsConfig {
158 inner: Arc<ArcSwap<ServerConfig>>,
159}
160
161impl RustlsConfig {
162 pub fn from_config(config: Arc<ServerConfig>) -> Self {
166 let inner = Arc::new(ArcSwap::new(config));
167
168 Self { inner }
169 }
170
171 pub async fn from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<Self> {
177 let server_config = spawn_blocking(|| config_from_der(cert, key))
178 .await
179 .unwrap()?;
180 let inner = Arc::new(ArcSwap::from_pointee(server_config));
181
182 Ok(Self { inner })
183 }
184
185 pub async fn from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<Self> {
189 let server_config = spawn_blocking(|| config_from_pem(cert, key))
190 .await
191 .unwrap()?;
192 let inner = Arc::new(ArcSwap::from_pointee(server_config));
193
194 Ok(Self { inner })
195 }
196
197 pub async fn from_pem_file(cert: impl AsRef<Path>, key: impl AsRef<Path>) -> io::Result<Self> {
201 let server_config = config_from_pem_file(cert, key).await?;
202 let inner = Arc::new(ArcSwap::from_pointee(server_config));
203
204 Ok(Self { inner })
205 }
206
207 pub fn get_inner(&self) -> Arc<ServerConfig> {
209 self.inner.load_full()
210 }
211
212 pub fn reload_from_config(&self, config: Arc<ServerConfig>) {
214 self.inner.store(config);
215 }
216
217 pub async fn reload_from_der(&self, cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<()> {
223 let server_config = spawn_blocking(|| config_from_der(cert, key))
224 .await
225 .unwrap()?;
226 let inner = Arc::new(server_config);
227
228 self.inner.store(inner);
229
230 Ok(())
231 }
232
233 pub async fn from_pem_chain_file(
236 chain: impl AsRef<Path>,
237 key: impl AsRef<Path>,
238 ) -> io::Result<Self> {
239 let server_config = config_from_pem_chain_file(chain, key).await?;
240 let inner = Arc::new(ArcSwap::from_pointee(server_config));
241
242 Ok(Self { inner })
243 }
244
245 pub async fn reload_from_pem(&self, cert: Vec<u8>, key: Vec<u8>) -> io::Result<()> {
249 let server_config = spawn_blocking(|| config_from_pem(cert, key))
250 .await
251 .unwrap()?;
252 let inner = Arc::new(server_config);
253
254 self.inner.store(inner);
255
256 Ok(())
257 }
258
259 pub async fn reload_from_pem_file(
263 &self,
264 cert: impl AsRef<Path>,
265 key: impl AsRef<Path>,
266 ) -> io::Result<()> {
267 let server_config = config_from_pem_file(cert, key).await?;
268 let inner = Arc::new(server_config);
269
270 self.inner.store(inner);
271
272 Ok(())
273 }
274}
275
276impl fmt::Debug for RustlsConfig {
277 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278 f.debug_struct("RustlsConfig").finish()
279 }
280}
281
282fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig> {
283 let cert = cert.into_iter().map(CertificateDer::from).collect();
284 let key = PrivateKeyDer::try_from(key).map_err(io_other)?;
285
286 let mut config = ServerConfig::builder()
287 .with_no_client_auth()
288 .with_single_cert(cert, key)
289 .map_err(io_other)?;
290
291 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
292
293 Ok(config)
294}
295
296fn config_from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<ServerConfig> {
297 let cert = rustls_pemfile::certs(&mut cert.as_ref())
298 .map(|it| it.map(|it| it.to_vec()))
299 .collect::<Result<Vec<_>, _>>()?;
300 let mut key_vec: Vec<Vec<u8>> = rustls_pemfile::read_all(&mut key.as_ref())
302 .filter_map(|i| match i.ok()? {
303 Item::Sec1Key(key) => Some(key.secret_sec1_der().to_vec()),
304 Item::Pkcs1Key(key) => Some(key.secret_pkcs1_der().to_vec().into()),
305 Item::Pkcs8Key(key) => Some(key.secret_pkcs8_der().to_vec().into()),
306 _ => None,
307 })
308 .collect();
309
310 if key_vec.len() != 1 {
312 return Err(io_other("private key format not supported"));
313 }
314
315 config_from_der(cert, key_vec.pop().unwrap())
316}
317
318async fn config_from_pem_file(
319 cert: impl AsRef<Path>,
320 key: impl AsRef<Path>,
321) -> io::Result<ServerConfig> {
322 let cert = tokio::fs::read(cert.as_ref()).await?;
323 let key = tokio::fs::read(key.as_ref()).await?;
324
325 config_from_pem(cert, key)
326}
327
328async fn config_from_pem_chain_file(
329 cert: impl AsRef<Path>,
330 chain: impl AsRef<Path>,
331) -> io::Result<ServerConfig> {
332 let cert = tokio::fs::read(cert.as_ref()).await?;
333 let cert = rustls_pemfile::certs(&mut cert.as_ref())
334 .map(|it| it.map(|it| CertificateDer::from(it.to_vec())))
335 .collect::<Result<Vec<_>, _>>()?;
336 let key = tokio::fs::read(chain.as_ref()).await?;
337 let key_cert: PrivateKeyDer = match rustls_pemfile::read_one(&mut key.as_ref())?
338 .ok_or_else(|| io_other("could not parse pem file"))?
339 {
340 Item::Pkcs8Key(key) => {
341 Ok(PrivateKeyDer::try_from(key.secret_pkcs8_der().to_vec()).map_err(io_other)?)
342 }
343 x => Err(io_other(format!(
344 "invalid certificate format, received: {x:?}"
345 ))),
346 }?;
347
348 ServerConfig::builder()
349 .with_no_client_auth()
350 .with_single_cert(cert, key_cert)
351 .map_err(|_| io_other("invalid certificate"))
352}
353
#[cfg(test)]
mod tests {
    use crate::handle::Handle;
    use crate::server::tests::slow_body;
    use crate::tls_rustls::{self, RustlsConfig};
    use axum::body::Body;
    use axum::response::Response;
    use axum::routing::{get, post};
    use axum::Router;
    use bytes::Bytes;
    use http::{response, Request};
    use http_body_util::BodyExt;
    use hyper::client::conn::http1::{handshake, SendRequest};
    use hyper_util::rt::TokioIo;
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::{ClientConfig, DigitallySignedStruct, Error, SignatureScheme};
    use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
    use std::fmt::Debug;
    use std::{convert::TryFrom, io, net::SocketAddr, sync::Arc, time::Duration};
    use tokio::sync::oneshot;
    use tokio::time::sleep;
    use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
    use tokio_rustls::TlsConnector;

    /// A single empty request over TLS is answered with the `/` route's body.
    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    /// A TCP connection that never starts a TLS handshake is counted while pending. It
    /// should be dropped once the handshake timeout (1 s when this crate is built for
    /// tests) has elapsed.
    ///
    /// NOTE(review): marked `#[ignore]`, presumably because the fixed sleeps make it
    /// timing-sensitive — confirm.
    #[ignore]
    #[tokio::test]
    async fn tls_timeout() {
        let (handle, _server_task, addr) = start_server().await;
        assert_eq!(handle.connection_count(), 0);

        // Raw TCP only: no TLS ClientHello is ever sent on this stream.
        let _stream = TcpStream::connect(addr).await.unwrap();

        sleep(Duration::from_millis(500)).await;
        assert_eq!(handle.connection_count(), 1);

        // 500 ms + 1000 ms is past the 1 s test handshake timeout.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        assert_eq!(handle.connection_count(), 0);
    }

    /// Reloading the shared `RustlsConfig` changes the certificate presented on new
    /// connections. Reloading the original files restores the original certificate.
    #[tokio::test]
    async fn test_reload() {
        let handle = Handle::new();

        let config = RustlsConfig::from_pem_file(
            "examples/self-signed-certs/cert.pem",
            "examples/self-signed-certs/key.pem",
        )
        .await
        .unwrap();

        // The server receives a clone; reloads through `config` affect it too.
        let server_handle = handle.clone();
        let rustls_config = config.clone();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            // Port 0: the OS picks a free port; the actual address comes from `listening()`.
            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_rustls::bind_rustls(addr, rustls_config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        // Two connections before any reload see the same certificate.
        let cert_a = get_first_cert(addr).await;
        let mut cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/reload/cert.pem",
                "examples/self-signed-certs/reload/key.pem",
            )
            .await
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_ne!(cert_a, cert_b);

        // Switch back to the original files.
        config
            .reload_from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .await
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);
    }

    /// After `shutdown()`, a request on an already-open connection fails and the client
    /// connection task finishes within one second.
    #[tokio::test]
    async fn test_shutdown() {
        let (handle, _server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.shutdown();

        let response_future_result = client.send_request(Request::new(Body::empty())).await;

        assert!(response_future_result.is_err());

        // The `unwrap` panics if the connection task is still running after 1 s.
        let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
    }

    /// Graceful shutdown with a 333 ms grace period. A short in-flight request completes,
    /// a very long one is cut off, and the server task returns `Ok` within one second.
    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client1, _conn1) = connect(addr).await;
        let (mut client2, _conn2) = connect(addr).await;

        // Complete one request on each connection before the timed part starts.
        crate::server::tests::do_empty_request(&mut client1)
            .await
            .unwrap();
        crate::server::tests::do_empty_request(&mut client2)
            .await
            .unwrap();

        let start = tokio::time::Instant::now();

        // Lets task3 wait until request 1 has been sent before shutting down.
        let (hdr1_tx, hdr1_rx) = oneshot::channel::<()>();

        let task1 = async {
            // NOTE(review): presumably sends a request whose body takes ~222 ms to
            // transmit and resolves to the response head — confirm in `server::tests`.
            let hdr1 =
                crate::server::tests::send_slow_request(&mut client1, Duration::from_millis(222))
                    .await;
            hdr1_tx.send(()).unwrap();

            // Request 1 fits inside the grace period, so its body must arrive intact.
            let res1 = crate::server::tests::recv_slow_response_body(hdr1.unwrap()).await;
            res1.unwrap();
        };
        let task2 = async {
            // Request 2 lasts far longer than the 333 ms grace period and must fail.
            let hdr2 =
                crate::server::tests::send_slow_request(&mut client2, Duration::from_millis(5_555))
                    .await;
            hdr2.unwrap_err();
        };
        let task3 = async {
            hdr1_rx.await.unwrap();

            handle.graceful_shutdown(Some(Duration::from_millis(333)));

            // Unwraps: no 1 s timeout, no task panic, and the server returned `Ok`.
            timeout(Duration::from_secs(1), server_task)
                .await
                .unwrap()
                .unwrap()
                .unwrap();

            // Shutdown waited for the grace period but did not wait for request 2.
            assert!(start.elapsed() >= Duration::from_millis(222 + 333));
            assert!(start.elapsed() <= Duration::from_millis(5_555));
        };

        tokio::join!(task1, task2, task3);
    }

    /// Spawn a TLS server on an OS-assigned localhost port using the bundled
    /// self-signed certificate. Returns its handle, task and listening address.
    ///
    /// Routes: `GET /` returns `"Hello, world!"`; `POST /echo_slowly` responds with
    /// `slow_body(len, 100 ms)`, where `len` is the request body's length.
    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new()
                .route("/", get(|| async { "Hello, world!" }))
                .route(
                    "/echo_slowly",
                    post(|body: Bytes| async move {
                        Response::new(slow_body(body.len(), Duration::from_millis(100)))
                    }),
                );

            let config = RustlsConfig::from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .await?;

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_rustls::bind_rustls(addr, config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Perform a TLS handshake with `addr` and return the first certificate the
    /// server presented.
    async fn get_first_cert(addr: SocketAddr) -> CertificateDer<'static> {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap();

        let (_io, client_connection) = tls_stream.into_inner();

        client_connection.peer_certificates().unwrap()[0].clone()
    }

    /// Open a TLS connection to `addr` and perform an HTTP/1 client handshake on it.
    ///
    /// The hyper connection is driven on a spawned task (its result is discarded). That
    /// task's handle is returned alongside the request sender.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = TokioIo::new(tls_connector().connect(dns_name(), stream).await.unwrap());

        let (send_request, connection) = handshake(tls_stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    /// Send a default (empty-body) request and collect the full response body.
    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let (parts, body) = client
            .send_request(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_parts();
        let body = body.collect().await.unwrap().to_bytes();

        (parts, body)
    }

    /// Client-side TLS connector for tests that accepts ANY server certificate.
    ///
    /// Certificate and signature checks are skipped entirely because the server uses
    /// a self-signed certificate. Never use this outside tests.
    fn tls_connector() -> TlsConnector {
        #[derive(Debug)]
        struct NoVerify;

        impl ServerCertVerifier for NoVerify {
            fn verify_server_cert(
                &self,
                _end_entity: &CertificateDer,
                _intermediates: &[CertificateDer],
                _server_name: &ServerName,
                _ocsp_response: &[u8],
                _now: UnixTime,
            ) -> Result<ServerCertVerified, rustls::Error> {
                Ok(ServerCertVerified::assertion())
            }

            fn verify_tls12_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, Error> {
                Ok(HandshakeSignatureValid::assertion())
            }

            fn verify_tls13_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, Error> {
                Ok(HandshakeSignatureValid::assertion())
            }

            // Only RSA schemes are advertised.
            // NOTE(review): presumably matches the bundled test certificate's key type — confirm.
            fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
                vec![
                    SignatureScheme::RSA_PKCS1_SHA1,
                    SignatureScheme::RSA_PKCS1_SHA256,
                    SignatureScheme::RSA_PKCS1_SHA384,
                    SignatureScheme::RSA_PKCS1_SHA512,
                    SignatureScheme::RSA_PSS_SHA256,
                    SignatureScheme::RSA_PSS_SHA384,
                    SignatureScheme::RSA_PSS_SHA512,
                ]
            }
        }

        let mut client_config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoVerify))
            .with_no_client_auth();

        // Same ALPN list the server-side DER/PEM constructors configure.
        client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

        TlsConnector::from(Arc::new(client_config))
    }

    /// Server name sent during the handshake; `NoVerify` never checks it.
    fn dns_name() -> ServerName<'static> {
        ServerName::try_from("localhost").unwrap()
    }
}