1use std::fs::File;
9use std::io::BufReader;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::Duration;
13
14use anyhow::{Context, Result};
15use axum::http::Request;
16use axum::Router;
17use hyper::body::Incoming;
18use hyper_util::rt::{TokioExecutor, TokioIo};
19use hyper_util::server::conn::auto::Builder as ConnBuilder;
20use rustls::pki_types::{CertificateDer, PrivateKeyDer};
21use rustls::ServerConfig;
22use tokio::net::TcpListener;
23use tokio::sync::watch;
24use tokio_rustls::TlsAcceptor;
25use tower::{Service, ServiceExt};
26use tracing::{debug, info, warn};
27
28pub fn init_crypto() {
32 let _ = rustls::crypto::ring::default_provider().install_default();
33}
34
35pub fn load_server_config(cert_path: &str, key_path: &str) -> Result<Arc<ServerConfig>> {
39 let certs = load_certs(cert_path)?;
40 let key = load_key(key_path)?;
41
42 let mut config =
43 ServerConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
44 .with_safe_default_protocol_versions()
45 .context("selecting TLS protocol versions")?
46 .with_no_client_auth()
47 .with_single_cert(certs, key)
48 .context("building rustls ServerConfig (does the key match the certificate?)")?;
49 config.alpn_protocols = vec![b"http/1.1".to_vec()];
50 Ok(Arc::new(config))
51}
52
53fn load_certs(path: &str) -> Result<Vec<CertificateDer<'static>>> {
54 let file = File::open(path).with_context(|| format!("opening certificate file {path}"))?;
55 let mut reader = BufReader::new(file);
56 let certs = rustls_pemfile::certs(&mut reader)
57 .collect::<Result<Vec<_>, _>>()
58 .with_context(|| format!("parsing certificates from {path}"))?;
59 anyhow::ensure!(!certs.is_empty(), "no certificates found in {path}");
60 Ok(certs)
61}
62
63fn load_key(path: &str) -> Result<PrivateKeyDer<'static>> {
64 let file = File::open(path).with_context(|| format!("opening private key file {path}"))?;
65 let mut reader = BufReader::new(file);
66 rustls_pemfile::private_key(&mut reader)
67 .with_context(|| format!("parsing private key from {path}"))?
68 .with_context(|| format!("no private key found in {path}"))
69}
70
71pub async fn serve(
75 listener: TcpListener,
76 config: Arc<ServerConfig>,
77 app: Router,
78 mut shutdown: watch::Receiver<bool>,
79) -> Result<()> {
80 let acceptor = TlsAcceptor::from(config);
81 let mut make_service = app.into_make_service_with_connect_info::<SocketAddr>();
84
85 info!(listen = %listener.local_addr().map(|a| a.to_string()).unwrap_or_default(), "TLS listener up");
86
87 loop {
88 let (stream, peer) = tokio::select! {
89 _ = shutdown.changed() => {
90 if *shutdown.borrow() { break; }
91 continue;
92 }
93 accepted = listener.accept() => match accepted {
94 Ok(v) => v,
95 Err(e) => { warn!(error = %e, "TLS accept error"); continue; }
96 },
97 };
98
99 let acceptor = acceptor.clone();
100 let tower_service = unwrap_infallible(make_service.call(peer).await);
102
103 tokio::spawn(async move {
104 let tls_stream = match tokio::time::timeout(
107 Duration::from_secs(10),
108 acceptor.accept(stream),
109 )
110 .await
111 {
112 Ok(Ok(s)) => s,
113 Ok(Err(e)) => {
114 debug!(error = %e, %peer, "TLS handshake failed");
115 return;
116 }
117 Err(_) => {
118 debug!(%peer, "TLS handshake timed out");
119 return;
120 }
121 };
122 let io = TokioIo::new(tls_stream);
123 let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
124 tower_service.clone().oneshot(request)
125 });
126 if let Err(e) = ConnBuilder::new(TokioExecutor::new())
127 .serve_connection_with_upgrades(io, hyper_service)
128 .await
129 {
130 debug!(error = %e, %peer, "error serving TLS connection");
131 }
132 });
133 }
134 Ok(())
135}
136
/// Extracts the success value from a `Result` whose error type is uninhabited.
///
/// `Infallible` has no values, so the error closure can never run. Matching on
/// it with zero arms proves that to the compiler without introducing a panic
/// path.
fn unwrap_infallible<T>(result: Result<T, std::convert::Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `contents` to a per-process, per-name file in the system temp
    /// directory and returns its path. The caller is responsible for removing it.
    fn temp_fixture(name: &str, contents: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!("{name}-{}.pem", std::process::id()));
        std::fs::write(&path, contents).expect("writing temp fixture");
        path
    }

    #[test]
    fn load_server_config_errors_on_missing_files() {
        assert!(load_server_config("/no/such/cert.pem", "/no/such/key.pem").is_err());
    }

    #[test]
    fn missing_cert_error_names_the_path() {
        let err = load_server_config("/no/such/cert.pem", "/no/such/key.pem")
            .err()
            .expect("missing certificate file must be rejected");
        assert!(
            err.to_string().contains("/no/such/cert.pem"),
            "unexpected error: {err:#}"
        );
    }

    #[test]
    fn empty_cert_file_is_rejected() {
        let path = temp_fixture("tls-test-empty-cert", "");
        let result = load_certs(path.to_str().expect("temp path is UTF-8"));
        let _ = std::fs::remove_file(&path);
        let err = result.err().expect("empty certificate file must be rejected");
        assert!(
            err.to_string().contains("no certificates found"),
            "unexpected error: {err:#}"
        );
    }

    #[test]
    fn empty_key_file_is_rejected() {
        let path = temp_fixture("tls-test-empty-key", "");
        let result = load_key(path.to_str().expect("temp path is UTF-8"));
        let _ = std::fs::remove_file(&path);
        let err = result.err().expect("empty key file must be rejected");
        assert!(
            err.to_string().contains("no private key found"),
            "unexpected error: {err:#}"
        );
    }

    #[test]
    fn unwrap_infallible_returns_ok_value() {
        assert_eq!(unwrap_infallible(Ok::<_, std::convert::Infallible>(42)), 42);
    }
}