1use self::future::RustlsAcceptorFuture;
31use crate::{
32 accept::{Accept, DefaultAcceptor},
33 server::{io_other, Server},
34};
35use arc_swap::ArcSwap;
36use rustls::{Certificate, PrivateKey, ServerConfig};
37use std::time::Duration;
38use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc};
39use tokio::{
40 io::{AsyncRead, AsyncWrite},
41 task::spawn_blocking,
42};
43use tokio_rustls::server::TlsStream;
44
45pub(crate) mod export {
47 use super::{RustlsAcceptor, RustlsConfig, Server, SocketAddr};
48
49 #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
51 pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
52 super::bind_rustls(addr, config)
53 }
54
55 #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
57 pub fn from_tcp_rustls(
58 listener: std::net::TcpListener,
59 config: RustlsConfig,
60 ) -> Server<RustlsAcceptor> {
61 let acceptor = RustlsAcceptor::new(config);
62
63 Server::from_tcp(listener).acceptor(acceptor)
64 }
65}
66
67pub mod future;
68
69pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
71 let acceptor = RustlsAcceptor::new(config);
72
73 Server::bind(addr).acceptor(acceptor)
74}
75
76pub fn from_tcp_rustls(
78 listener: std::net::TcpListener,
79 config: RustlsConfig,
80) -> Server<RustlsAcceptor> {
81 let acceptor = RustlsAcceptor::new(config);
82
83 Server::from_tcp(listener).acceptor(acceptor)
84}
85
/// Acceptor that runs an inner acceptor (by default [`DefaultAcceptor`]) and then wraps the
/// resulting stream in a rustls [`TlsStream`].
#[derive(Clone)]
pub struct RustlsAcceptor<A = DefaultAcceptor> {
    // Acceptor run first; its output stream is the transport the TLS session runs over.
    inner: A,
    // Shared, reloadable TLS configuration; a clone is handed to every accept future.
    config: RustlsConfig,
    // Time allowed for the TLS handshake of each accepted connection.
    handshake_timeout: Duration,
}
93
94impl RustlsAcceptor {
95 pub fn new(config: RustlsConfig) -> Self {
97 let inner = DefaultAcceptor::new();
98
99 #[cfg(not(test))]
102 let handshake_timeout = Duration::from_secs(10);
103 #[cfg(test)]
104 let handshake_timeout = Duration::from_secs(1);
105
106 Self {
107 inner,
108 config,
109 handshake_timeout,
110 }
111 }
112
113 pub fn handshake_timeout(mut self, val: Duration) -> Self {
115 self.handshake_timeout = val;
116 self
117 }
118}
119
120impl<A> RustlsAcceptor<A> {
121 pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> RustlsAcceptor<Acceptor> {
123 RustlsAcceptor {
124 inner: acceptor,
125 config: self.config,
126 handshake_timeout: self.handshake_timeout,
127 }
128 }
129}
130
131impl<A, I, S> Accept<I, S> for RustlsAcceptor<A>
133where
134 A: Accept<I, S>,
135 A::Stream: AsyncRead + AsyncWrite + Unpin,
136{
137 type Stream = TlsStream<A::Stream>;
138 type Service = A::Service;
139 type Future = RustlsAcceptorFuture<A::Future, A::Stream, A::Service>;
140
141 fn accept(&self, stream: I, service: S) -> Self::Future {
142 let inner_future = self.inner.accept(stream, service);
143 let config = self.config.clone();
144
145 RustlsAcceptorFuture::new(inner_future, config, self.handshake_timeout)
146 }
147}
148
149impl<A> fmt::Debug for RustlsAcceptor<A> {
150 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151 f.debug_struct("RustlsAcceptor").finish()
152 }
153}
154
/// Rustls server configuration that can be replaced at runtime through the `reload_*`
/// methods.
///
/// Clones share the same underlying slot, so a reload through any clone is visible to all of
/// them (including the one held by a running server's acceptor).
#[derive(Clone)]
pub struct RustlsConfig {
    // `Arc` so clones share the slot; `ArcSwap` so the current config can be swapped atomically.
    inner: Arc<ArcSwap<ServerConfig>>,
}
160
161impl RustlsConfig {
163 pub fn from_config(config: Arc<ServerConfig>) -> Self {
168 let inner = Arc::new(ArcSwap::new(config));
169 Self { inner }
170 }
171
172 pub async fn from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<Self> {
178 let server_config = spawn_blocking(|| config_from_der(cert, key))
179 .await
180 .unwrap()?;
181 let inner = Arc::new(ArcSwap::from_pointee(server_config));
182 Ok(Self { inner })
183 }
184
185 pub async fn from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<Self> {
190 let server_config = spawn_blocking(|| config_from_pem(cert, key))
191 .await
192 .unwrap()?;
193 let inner = Arc::new(ArcSwap::from_pointee(server_config));
194 Ok(Self { inner })
195 }
196
197 pub async fn from_pem_file(cert: impl AsRef<Path>, key: impl AsRef<Path>) -> io::Result<Self> {
201 let server_config = config_from_pem_file(cert, key).await?;
202 let inner = Arc::new(ArcSwap::from_pointee(server_config));
203 Ok(Self { inner })
204 }
205
206 pub fn get_inner(&self) -> Arc<ServerConfig> {
208 self.inner.load_full()
209 }
210
211 pub fn reload_from_config(&self, config: Arc<ServerConfig>) {
213 self.inner.store(config);
214 }
215
216 pub async fn reload_from_der(&self, cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<()> {
221 let server_config = spawn_blocking(|| config_from_der(cert, key))
222 .await
223 .unwrap()?;
224 let inner = Arc::new(server_config);
225 self.inner.store(inner);
226 Ok(())
227 }
228
229 pub async fn reload_from_pem(&self, cert: Vec<u8>, key: Vec<u8>) -> io::Result<()> {
231 let server_config = spawn_blocking(|| config_from_pem(cert, key))
232 .await
233 .unwrap()?;
234 let inner = Arc::new(server_config);
235 self.inner.store(inner);
236 Ok(())
237 }
238
239 pub async fn reload_from_pem_file(
241 &self,
242 cert: impl AsRef<Path>,
243 key: impl AsRef<Path>,
244 ) -> io::Result<()> {
245 let server_config = config_from_pem_file(cert, key).await?;
246 let inner = Arc::new(server_config);
247 self.inner.store(inner);
248 Ok(())
249 }
250}
251
252impl fmt::Debug for RustlsConfig {
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 f.debug_struct("RustlsConfig").finish()
256 }
257}
258
259fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig> {
261 let cert = cert.into_iter().map(Certificate).collect();
263 let key = PrivateKey(key);
264
265 let mut config = ServerConfig::builder()
267 .with_safe_defaults()
268 .with_no_client_auth()
269 .with_single_cert(cert, key)
270 .map_err(io_other)?;
271
272 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
274
275 Ok(config)
276}
277
278fn config_from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<ServerConfig> {
280 use rustls_pemfile::Item;
281
282 let cert = rustls_pemfile::certs(&mut cert.as_ref())?;
284 let key = match rustls_pemfile::read_one(&mut key.as_ref())? {
285 Some(Item::RSAKey(key)) | Some(Item::PKCS8Key(key)) | Some(Item::ECKey(key)) => key,
286 _ => return Err(io_other("private key format not supported")),
287 };
288
289 config_from_der(cert, key)
290}
291
292async fn config_from_pem_file(
294 cert: impl AsRef<Path>,
295 key: impl AsRef<Path>,
296) -> io::Result<ServerConfig> {
297 let cert = tokio::fs::read(cert.as_ref()).await?;
299 let key = tokio::fs::read(key.as_ref()).await?;
300
301 config_from_pem(cert, key)
302}
303
#[cfg(test)]
pub(crate) mod tests {
    use crate::{
        handle::Handle,
        tls_rustls::{self, RustlsConfig},
    };
    use axum::{routing::get, Router};
    use bytes::Bytes;
    use http::{response, Request};
    use hyper::{
        client::conn::{handshake, SendRequest},
        Body,
    };
    use rustls::{
        client::{ServerCertVerified, ServerCertVerifier},
        Certificate, ClientConfig, ServerName,
    };
    use std::{
        convert::TryFrom,
        io,
        net::SocketAddr,
        sync::Arc,
        time::{Duration, SystemTime},
    };
    use tokio::time::sleep;
    use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
    use tokio_rustls::TlsConnector;
    use tower::{Service, ServiceExt};

    /// A full TLS + HTTP round trip against the test server returns the handler's body.
    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    /// A TCP connection that never starts the TLS handshake is counted while pending and is
    /// dropped once the handshake timeout (1 second under `cfg(test)`) has elapsed.
    #[tokio::test]
    async fn tls_timeout() {
        let (handle, _server_task, addr) = start_server().await;
        assert_eq!(handle.connection_count(), 0);

        // Plain TCP connect only: the TLS handshake is never started.
        let _stream = TcpStream::connect(addr).await.unwrap();

        // Still inside the handshake timeout: the connection is tracked.
        sleep(Duration::from_millis(500)).await;
        assert_eq!(handle.connection_count(), 1);

        // 1.5 s total, past the 1 s test handshake timeout: the connection has been dropped.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        assert_eq!(handle.connection_count(), 0);
    }

    /// Reloading the shared `RustlsConfig` changes the certificate presented to new
    /// connections, and reloading the original files restores the original certificate.
    #[tokio::test]
    async fn test_reload() {
        let handle = Handle::new();

        let config = RustlsConfig::from_pem_file(
            "examples/self-signed-certs/cert.pem",
            "examples/self-signed-certs/key.pem",
        )
        .await
        .unwrap();

        // The server gets a clone; reloads through `config` must be visible to it.
        let server_handle = handle.clone();
        let rustls_config = config.clone();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_rustls::bind_rustls(addr, rustls_config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        // Without a reload, two connections see the same certificate.
        let cert_a = get_first_cert(addr).await;
        let mut cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/reload/cert.pem",
                "examples/self-signed-certs/reload/key.pem",
            )
            .await
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_ne!(cert_a, cert_b);

        // Switch back to the original files.
        config
            .reload_from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .await
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);
    }

    /// After `Handle::shutdown`, a request on an already-open connection fails and the client
    /// connection task finishes within 1 second.
    #[tokio::test]
    async fn test_shutdown() {
        let (handle, _server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.shutdown();

        let response_future_result = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await;

        assert!(response_future_result.is_err());

        // Panics (fails the test) if the client connection task is still running after 1 s.
        let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
    }

    /// During a graceful shutdown with no deadline, an already-open connection still serves
    /// requests; once the client connection is aborted the server task finishes with `Ok`.
    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.graceful_shutdown(None);

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");

        // Abort the client connection task so the server's last connection closes.
        conn.abort();

        let server_result = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();

        assert!(server_result.is_ok());
    }

    /// With a 250 ms graceful-shutdown deadline, the server task finishes with `Ok` within
    /// 1 second even though the client connection (`_conn`) is kept alive.
    #[tokio::test]
    async fn test_graceful_shutdown_timed() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        handle.graceful_shutdown(Some(Duration::from_millis(250)));

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");

        let server_result = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();

        assert!(server_result.is_ok());
    }

    /// Spawn a TLS server on an ephemeral port of 127.0.0.1, using the example self-signed
    /// certificate, that answers `/` with "Hello, world!".
    ///
    /// Returns the server's handle, its task, and the address it is listening on.
    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let config = RustlsConfig::from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .await?;

            // Port 0: let the OS pick a free port; the real address comes from `listening()`.
            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_rustls::bind_rustls(addr, config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Perform a TLS handshake with the server at `addr` and return the first certificate it
    /// presented.
    async fn get_first_cert(addr: SocketAddr) -> Certificate {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap();

        let (_io, client_connection) = tls_stream.into_inner();

        client_connection.peer_certificates().unwrap()[0].clone()
    }

    /// Open a TLS connection to `addr` and perform the hyper client handshake over it.
    ///
    /// Returns the request sender and the task that drives the connection.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap();

        let (send_request, connection) = handshake(tls_stream).await.unwrap();

        // The connection future must be polled for requests on `send_request` to make progress.
        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    /// Send a request built with `Request::new(Body::empty())` and collect the full response
    /// body. Panics on any error.
    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let (parts, body) = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_parts();
        let body = hyper::body::to_bytes(body).await.unwrap();

        (parts, body)
    }

    /// TLS client connector that accepts any server certificate, so the self-signed test
    /// certificates are accepted, and advertises `h2` and `http/1.1` via ALPN.
    ///
    /// Test-only: certificate verification is disabled entirely.
    pub(crate) fn tls_connector() -> TlsConnector {
        // Verifier that accepts every server certificate unconditionally.
        struct NoVerify;

        impl ServerCertVerifier for NoVerify {
            fn verify_server_cert(
                &self,
                _end_entity: &Certificate,
                _intermediates: &[Certificate],
                _server_name: &ServerName,
                _scts: &mut dyn Iterator<Item = &[u8]>,
                _ocsp_response: &[u8],
                _now: SystemTime,
            ) -> Result<ServerCertVerified, rustls::Error> {
                Ok(ServerCertVerified::assertion())
            }
        }

        let mut client_config = ClientConfig::builder()
            .with_safe_defaults()
            .with_custom_certificate_verifier(Arc::new(NoVerify))
            .with_no_client_auth();

        client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

        TlsConnector::from(Arc::new(client_config))
    }

    /// Server name ("localhost") used by the test TLS connections.
    pub(crate) fn dns_name() -> ServerName {
        ServerName::try_from("localhost").unwrap()
    }
}
579}