1use self::future::RustlsAcceptorFuture;
31use crate::{
32 accept::{Accept, DefaultAcceptor},
33 server::{io_other, Server},
34};
35use arc_swap::ArcSwap;
36use rustls::ServerConfig;
37use rustls_pemfile::Item;
38use rustls_pki_types::{CertificateDer, PrivateKeyDer};
39use std::time::Duration;
40use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc};
41use tokio::{
42 io::{AsyncRead, AsyncWrite},
43 task::spawn_blocking,
44};
45use tokio_rustls::server::TlsStream;
46
47pub(crate) mod export {
48 #[allow(clippy::wildcard_imports)]
49 use super::*;
50
51 #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
53 pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
54 super::bind_rustls(addr, config)
55 }
56
57 #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
59 pub fn from_tcp_rustls(
60 listener: std::net::TcpListener,
61 config: RustlsConfig,
62 ) -> Server<RustlsAcceptor> {
63 let acceptor = RustlsAcceptor::new(config);
64
65 Server::from_tcp(listener).acceptor(acceptor)
66 }
67}
68
69pub mod future;
70
71pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server<RustlsAcceptor> {
73 let acceptor = RustlsAcceptor::new(config);
74
75 Server::bind(addr).acceptor(acceptor)
76}
77
78pub fn from_tcp_rustls(
80 listener: std::net::TcpListener,
81 config: RustlsConfig,
82) -> Server<RustlsAcceptor> {
83 let acceptor = RustlsAcceptor::new(config);
84
85 Server::from_tcp(listener).acceptor(acceptor)
86}
87
/// Acceptor that runs an inner acceptor (by default [`DefaultAcceptor`]) and then wraps the
/// resulting stream in a rustls [`TlsStream`] using the shared [`RustlsConfig`].
#[derive(Clone)]
pub struct RustlsAcceptor<A = DefaultAcceptor> {
    /// Acceptor run before TLS; its output stream is the transport for the TLS session.
    inner: A,
    /// Shared, reloadable rustls server configuration (cloned into every accept future).
    config: RustlsConfig,
    /// Timeout handed to `RustlsAcceptorFuture` for every accepted connection.
    handshake_timeout: Duration,
}
95
96impl RustlsAcceptor {
97 pub fn new(config: RustlsConfig) -> Self {
99 let inner = DefaultAcceptor::new();
100
101 #[cfg(not(test))]
102 let handshake_timeout = Duration::from_secs(10);
103
104 #[cfg(test)]
106 let handshake_timeout = Duration::from_secs(1);
107
108 Self {
109 inner,
110 config,
111 handshake_timeout,
112 }
113 }
114
115 pub fn handshake_timeout(mut self, val: Duration) -> Self {
117 self.handshake_timeout = val;
118 self
119 }
120}
121
122impl<A> RustlsAcceptor<A> {
123 pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> RustlsAcceptor<Acceptor> {
125 RustlsAcceptor {
126 inner: acceptor,
127 config: self.config,
128 handshake_timeout: self.handshake_timeout,
129 }
130 }
131}
132
133impl<A, I, S> Accept<I, S> for RustlsAcceptor<A>
134where
135 A: Accept<I, S>,
136 A::Stream: AsyncRead + AsyncWrite + Unpin,
137{
138 type Stream = TlsStream<A::Stream>;
139 type Service = A::Service;
140 type Future = RustlsAcceptorFuture<A::Future, A::Stream, A::Service>;
141
142 fn accept(&self, stream: I, service: S) -> Self::Future {
143 let inner_future = self.inner.accept(stream, service);
144 let config = self.config.clone();
145
146 RustlsAcceptorFuture::new(inner_future, config, self.handshake_timeout)
147 }
148}
149
150impl<A> fmt::Debug for RustlsAcceptor<A> {
151 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152 f.debug_struct("RustlsAcceptor").finish()
153 }
154}
155
/// Rustls server configuration that can be shared between acceptors and replaced at runtime.
///
/// Cloning is cheap: clones share the same `Arc<ArcSwap<_>>`, so a `reload_*` call made
/// through any clone is visible to all of them.
#[derive(Clone)]
pub struct RustlsConfig {
    /// Current server config; swapped atomically by the `reload_*` methods.
    inner: Arc<ArcSwap<ServerConfig>>,
}
161
162impl RustlsConfig {
163 pub fn from_config(config: Arc<ServerConfig>) -> Self {
167 let inner = Arc::new(ArcSwap::new(config));
168
169 Self { inner }
170 }
171
172 pub async fn from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<Self> {
178 let server_config = spawn_blocking(|| config_from_der(cert, key))
179 .await
180 .unwrap()?;
181 let inner = Arc::new(ArcSwap::from_pointee(server_config));
182
183 Ok(Self { inner })
184 }
185
186 pub async fn from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<Self> {
190 let server_config = spawn_blocking(|| config_from_pem(cert, key))
191 .await
192 .unwrap()?;
193 let inner = Arc::new(ArcSwap::from_pointee(server_config));
194
195 Ok(Self { inner })
196 }
197
198 pub async fn from_pem_file(cert: impl AsRef<Path>, key: impl AsRef<Path>) -> io::Result<Self> {
202 let server_config = config_from_pem_file(cert, key).await?;
203 let inner = Arc::new(ArcSwap::from_pointee(server_config));
204
205 Ok(Self { inner })
206 }
207
208 pub fn get_inner(&self) -> Arc<ServerConfig> {
210 self.inner.load_full()
211 }
212
213 pub fn reload_from_config(&self, config: Arc<ServerConfig>) {
215 self.inner.store(config);
216 }
217
218 pub async fn reload_from_der(&self, cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<()> {
224 let server_config = spawn_blocking(|| config_from_der(cert, key))
225 .await
226 .unwrap()?;
227 let inner = Arc::new(server_config);
228
229 self.inner.store(inner);
230
231 Ok(())
232 }
233
234 pub async fn from_pem_chain_file(
237 chain: impl AsRef<Path>,
238 key: impl AsRef<Path>,
239 ) -> io::Result<Self> {
240 let server_config = config_from_pem_chain_file(chain, key).await?;
241 let inner = Arc::new(ArcSwap::from_pointee(server_config));
242
243 Ok(Self { inner })
244 }
245
246 pub async fn reload_from_pem(&self, cert: Vec<u8>, key: Vec<u8>) -> io::Result<()> {
250 let server_config = spawn_blocking(|| config_from_pem(cert, key))
251 .await
252 .unwrap()?;
253 let inner = Arc::new(server_config);
254
255 self.inner.store(inner);
256
257 Ok(())
258 }
259
260 pub async fn reload_from_pem_file(
264 &self,
265 cert: impl AsRef<Path>,
266 key: impl AsRef<Path>,
267 ) -> io::Result<()> {
268 let server_config = config_from_pem_file(cert, key).await?;
269 let inner = Arc::new(server_config);
270
271 self.inner.store(inner);
272
273 Ok(())
274 }
275}
276
277impl fmt::Debug for RustlsConfig {
278 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279 f.debug_struct("RustlsConfig").finish()
280 }
281}
282
283fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig> {
284 let cert = cert.into_iter().map(CertificateDer::from).collect();
285 let key = PrivateKeyDer::try_from(key).map_err(io_other)?;
286
287 let mut config = ServerConfig::builder()
288 .with_no_client_auth()
289 .with_single_cert(cert, key)
290 .map_err(io_other)?;
291
292 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
293
294 Ok(config)
295}
296
297fn config_from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<ServerConfig> {
298 let cert = rustls_pemfile::certs(&mut cert.as_ref())
299 .map(|it| it.map(|it| it.to_vec()))
300 .collect::<Result<Vec<_>, _>>()?;
301 let mut key_vec: Vec<Vec<u8>> = rustls_pemfile::read_all(&mut key.as_ref())
303 .filter_map(|i| match i.ok()? {
304 Item::Sec1Key(key) => Some(key.secret_sec1_der().to_vec()),
305 Item::Pkcs1Key(key) => Some(key.secret_pkcs1_der().to_vec()),
306 Item::Pkcs8Key(key) => Some(key.secret_pkcs8_der().to_vec()),
307 _ => None,
308 })
309 .collect();
310
311 if key_vec.len() != 1 {
313 return Err(io_other("private key format not supported"));
314 }
315
316 config_from_der(cert, key_vec.pop().unwrap())
317}
318
319async fn config_from_pem_file(
320 cert: impl AsRef<Path>,
321 key: impl AsRef<Path>,
322) -> io::Result<ServerConfig> {
323 let cert = fs_err::tokio::read(cert.as_ref()).await?;
324 let key = fs_err::tokio::read(key.as_ref()).await?;
325
326 config_from_pem(cert, key)
327}
328
329async fn config_from_pem_chain_file(
330 cert: impl AsRef<Path>,
331 chain: impl AsRef<Path>,
332) -> io::Result<ServerConfig> {
333 let cert = fs_err::tokio::read(cert.as_ref()).await?;
334 let cert = rustls_pemfile::certs(&mut cert.as_ref())
335 .map(|it| it.map(|it| CertificateDer::from(it.to_vec())))
336 .collect::<Result<Vec<_>, _>>()?;
337 let key = fs_err::tokio::read(chain.as_ref()).await?;
338 let key_cert: PrivateKeyDer = match rustls_pemfile::read_one(&mut key.as_ref())?
339 .ok_or_else(|| io_other("could not parse pem file"))?
340 {
341 Item::Pkcs8Key(key) => Ok(key.into()),
342 Item::Sec1Key(key) => Ok(key.into()),
343 Item::Pkcs1Key(key) => Ok(key.into()),
344 x => Err(io_other(format!(
345 "invalid certificate format, received: {x:?}"
346 ))),
347 }?;
348
349 ServerConfig::builder()
350 .with_no_client_auth()
351 .with_single_cert(cert, key_cert)
352 .map_err(|_| io_other("invalid certificate"))
353}
354
// Tests that start a real TLS server on an ephemeral localhost port, using the self-signed
// certificates under `examples/self-signed-certs`, and talk to it with a rustls client that
// skips certificate verification.
#[cfg(test)]
mod tests {
    use crate::handle::Handle;
    use crate::tls_rustls::{self, RustlsConfig};
    use axum::body::Body;
    use axum::routing::get;
    use axum::Router;
    use bytes::Bytes;
    use http::{response, Request};
    use http_body_util::BodyExt;
    use hyper::client::conn::http1::{handshake, SendRequest};
    use hyper_util::rt::TokioIo;
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::{ClientConfig, DigitallySignedStruct, Error, SignatureScheme};
    use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
    use std::fmt::Debug;
    use std::{convert::TryFrom, io, net::SocketAddr, sync::Arc, time::Duration};
    use tokio::time::sleep;
    use tokio::{net::TcpStream, task::JoinHandle};
    use tokio_rustls::TlsConnector;

    /// An HTTP/1 request over TLS reaches the router and returns the route's body.
    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    /// A TCP connection that never starts a TLS handshake is counted while pending and is
    /// dropped once the acceptor's handshake timeout (1 second under `cfg(test)`) expires.
    ///
    /// NOTE(review): ignored by default, presumably because it is timing-sensitive — confirm.
    #[ignore]
    #[tokio::test]
    async fn tls_timeout() {
        let (handle, _server_task, addr) = start_server().await;
        assert_eq!(handle.connection_count(), 0);

        // Open a raw TCP connection but never send a TLS ClientHello.
        let _stream = TcpStream::connect(addr).await.unwrap();

        // Still inside the handshake timeout: the connection is tracked.
        sleep(Duration::from_millis(500)).await;
        assert_eq!(handle.connection_count(), 1);

        // ~1.5 s after connecting, past the 1 s test timeout: the connection is gone.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        assert_eq!(handle.connection_count(), 0);
    }

    /// Reloading the shared `RustlsConfig` changes the certificate presented on new
    /// connections, and reloading the original files restores the original certificate.
    #[tokio::test]
    async fn test_reload() {
        let handle = Handle::new();

        let config = RustlsConfig::from_pem_file(
            "examples/self-signed-certs/cert.pem",
            "examples/self-signed-certs/key.pem",
        )
        .await
        .unwrap();

        // The server gets a clone; reloads through `config` are visible to it because clones
        // share the same inner swappable config.
        let server_handle = handle.clone();
        let rustls_config = config.clone();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_rustls::bind_rustls(addr, rustls_config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        let cert_a = get_first_cert(addr).await;
        let mut cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/reload/cert.pem",
                "examples/self-signed-certs/reload/key.pem",
            )
            .await
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_ne!(cert_a, cert_b);

        config
            .reload_from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .await
            .unwrap();

        cert_b = get_first_cert(addr).await;

        assert_eq!(cert_a, cert_b);
    }

    /// Spawns a TLS server on `127.0.0.1:0` serving "Hello, world!" at `/`.
    ///
    /// Returns the server handle, the server task, and the actual bound address, obtained by
    /// awaiting `Handle::listening`.
    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            let config = RustlsConfig::from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .await?;

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_rustls::bind_rustls(addr, config)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Performs a TLS handshake with `addr` and returns the first certificate the server
    /// presented.
    async fn get_first_cert(addr: SocketAddr) -> CertificateDer<'static> {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap();

        let (_io, client_connection) = tls_stream.into_inner();

        client_connection.peer_certificates().unwrap()[0].clone()
    }

    /// Opens a TLS connection to `addr` and performs an HTTP/1 client handshake over it.
    ///
    /// The hyper connection driver runs in the returned spawned task; its result is ignored.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let tls_stream = TokioIo::new(tls_connector().connect(dns_name(), stream).await.unwrap());

        let (send_request, connection) = handshake(tls_stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    /// Sends a default `Request` with an empty body and collects the full response body.
    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let (parts, body) = client
            .send_request(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_parts();
        let body = body.collect().await.unwrap().to_bytes();

        (parts, body)
    }

    /// Builds a client connector that accepts any server certificate (test-only) and
    /// advertises `h2` and `http/1.1` via ALPN.
    fn tls_connector() -> TlsConnector {
        // Verifier that unconditionally accepts certificates and signatures, so the
        // self-signed example certificates are trusted.
        #[derive(Debug)]
        struct NoVerify;

        impl ServerCertVerifier for NoVerify {
            fn verify_server_cert(
                &self,
                _end_entity: &CertificateDer,
                _intermediates: &[CertificateDer],
                _server_name: &ServerName,
                _ocsp_response: &[u8],
                _now: UnixTime,
            ) -> Result<ServerCertVerified, rustls::Error> {
                Ok(ServerCertVerified::assertion())
            }

            fn verify_tls12_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, Error> {
                Ok(HandshakeSignatureValid::assertion())
            }

            fn verify_tls13_signature(
                &self,
                _message: &[u8],
                _cert: &CertificateDer<'_>,
                _dss: &DigitallySignedStruct,
            ) -> Result<HandshakeSignatureValid, Error> {
                Ok(HandshakeSignatureValid::assertion())
            }

            // NOTE(review): only RSA schemes are offered; presumably the example certificates
            // use RSA keys — confirm if ECDSA/Ed25519 test certs are ever added.
            fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
                vec![
                    SignatureScheme::RSA_PKCS1_SHA1,
                    SignatureScheme::RSA_PKCS1_SHA256,
                    SignatureScheme::RSA_PKCS1_SHA384,
                    SignatureScheme::RSA_PKCS1_SHA512,
                    SignatureScheme::RSA_PSS_SHA256,
                    SignatureScheme::RSA_PSS_SHA384,
                    SignatureScheme::RSA_PSS_SHA512,
                ]
            }
        }

        let mut client_config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoVerify))
            .with_no_client_auth();

        client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

        TlsConnector::from(Arc::new(client_config))
    }

    /// Server name (SNI) used for every test connection.
    fn dns_name() -> ServerName<'static> {
        ServerName::try_from("localhost").unwrap()
    }

    /// A missing certificate file or key file surfaces as `NotFound`, with the `fs_err`
    /// message naming the missing path.
    ///
    /// NOTE(review): the "(os error 2)" text is the Unix wording of ENOENT; this exact-string
    /// assertion likely differs on other platforms — confirm.
    #[tokio::test]
    async fn from_pem_file_not_found() {
        let err = RustlsConfig::from_pem_file(
            "examples/self-signed-certs/missing.pem",
            "examples/self-signed-certs/key.pem",
        )
        .await
        .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        assert_eq!(err.to_string(), "failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)");

        let err = RustlsConfig::from_pem_file(
            "examples/self-signed-certs/cert.pem",
            "examples/self-signed-certs/missing.pem",
        )
        .await
        .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        assert_eq!(err.to_string(), "failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)");
    }

    /// Same as `from_pem_file_not_found`, but for `RustlsConfig::from_pem_chain_file`.
    #[tokio::test]
    async fn from_pem_file_chain_file_not_found() {
        let err = RustlsConfig::from_pem_chain_file(
            "examples/self-signed-certs/missing.pem",
            "examples/self-signed-certs/key.pem",
        )
        .await
        .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        assert_eq!(err.to_string(), "failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)");

        let err = RustlsConfig::from_pem_chain_file(
            "examples/self-signed-certs/cert.pem",
            "examples/self-signed-certs/missing.pem",
        )
        .await
        .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        assert_eq!(err.to_string(), "failed to read from file `examples/self-signed-certs/missing.pem`: No such file or directory (os error 2)");
    }
}