Skip to main content

edgeguard/
tls.rs

1//! TLS termination via `rustls` + `tokio-rustls`.
2//!
3//! axum 0.7 has no built-in TLS, so we run a small accept loop: take a TCP connection,
4//! complete the rustls handshake, then hand the encrypted stream to hyper, serving the same
5//! axum [`Router`] the plaintext path uses. Certificates come either from PEM files
6//! ([`load_server_config`]) or from ACME (which writes those same files; see [`crate::acme`]).
7
8use std::fs::File;
9use std::io::BufReader;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::Duration;
13
14use anyhow::{Context, Result};
15use axum::http::Request;
16use axum::Router;
17use hyper::body::Incoming;
18use hyper_util::rt::{TokioExecutor, TokioIo};
19use hyper_util::server::conn::auto::Builder as ConnBuilder;
20use rustls::pki_types::{CertificateDer, PrivateKeyDer};
21use rustls::ServerConfig;
22use tokio::net::TcpListener;
23use tokio::sync::watch;
24use tokio_rustls::TlsAcceptor;
25use tower::{Service, ServiceExt};
26use tracing::{debug, info, warn};
27
28/// Install a process-wide default crypto provider (ring). Idempotent and best-effort: if a
29/// provider is already installed (e.g. by the JWKS HTTP client) this is a no-op. Pinning one
30/// avoids the "no process-level CryptoProvider" ambiguity when multiple providers are linked.
31pub fn init_crypto() {
32    let _ = rustls::crypto::ring::default_provider().install_default();
33}
34
35/// Build a rustls [`ServerConfig`] from a PEM certificate chain and private key. Uses an
36/// explicit ring provider so it doesn't depend on which provider happens to be the process
37/// default. Advertises HTTP/1.1 via ALPN (the proxy speaks HTTP/1.1 upstream).
38pub fn load_server_config(cert_path: &str, key_path: &str) -> Result<Arc<ServerConfig>> {
39    let certs = load_certs(cert_path)?;
40    let key = load_key(key_path)?;
41
42    let mut config =
43        ServerConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
44            .with_safe_default_protocol_versions()
45            .context("selecting TLS protocol versions")?
46            .with_no_client_auth()
47            .with_single_cert(certs, key)
48            .context("building rustls ServerConfig (does the key match the certificate?)")?;
49    config.alpn_protocols = vec![b"http/1.1".to_vec()];
50    Ok(Arc::new(config))
51}
52
53fn load_certs(path: &str) -> Result<Vec<CertificateDer<'static>>> {
54    let file = File::open(path).with_context(|| format!("opening certificate file {path}"))?;
55    let mut reader = BufReader::new(file);
56    let certs = rustls_pemfile::certs(&mut reader)
57        .collect::<Result<Vec<_>, _>>()
58        .with_context(|| format!("parsing certificates from {path}"))?;
59    anyhow::ensure!(!certs.is_empty(), "no certificates found in {path}");
60    Ok(certs)
61}
62
63fn load_key(path: &str) -> Result<PrivateKeyDer<'static>> {
64    let file = File::open(path).with_context(|| format!("opening private key file {path}"))?;
65    let mut reader = BufReader::new(file);
66    rustls_pemfile::private_key(&mut reader)
67        .with_context(|| format!("parsing private key from {path}"))?
68        .with_context(|| format!("no private key found in {path}"))
69}
70
71/// Serve `app` over TLS on `listener` until `shutdown` flips true. Each connection is
72/// handshaked and served on its own task, so a slow handshake can't block new accepts and a
73/// graceful shutdown stops accepting while letting the listener drop.
74pub async fn serve(
75    listener: TcpListener,
76    config: Arc<ServerConfig>,
77    app: Router,
78    mut shutdown: watch::Receiver<bool>,
79) -> Result<()> {
80    let acceptor = TlsAcceptor::from(config);
81    // `into_make_service_with_connect_info` injects `ConnectInfo(peer)` per connection, which
82    // the proxy handler relies on for client-IP resolution.
83    let mut make_service = app.into_make_service_with_connect_info::<SocketAddr>();
84
85    info!(listen = %listener.local_addr().map(|a| a.to_string()).unwrap_or_default(), "TLS listener up");
86
87    loop {
88        let (stream, peer) = tokio::select! {
89            _ = shutdown.changed() => {
90                if *shutdown.borrow() { break; }
91                continue;
92            }
93            accepted = listener.accept() => match accepted {
94                Ok(v) => v,
95                Err(e) => { warn!(error = %e, "TLS accept error"); continue; }
96            },
97        };
98
99        let acceptor = acceptor.clone();
100        // Connection-scoped tower service carrying this peer's ConnectInfo.
101        let tower_service = unwrap_infallible(make_service.call(peer).await);
102
103        tokio::spawn(async move {
104            // Bound the handshake so a client that never completes it can't pin a task/socket
105            // indefinitely (this runs before any auth/rate-limit checks).
106            let tls_stream = match tokio::time::timeout(
107                Duration::from_secs(10),
108                acceptor.accept(stream),
109            )
110            .await
111            {
112                Ok(Ok(s)) => s,
113                Ok(Err(e)) => {
114                    debug!(error = %e, %peer, "TLS handshake failed");
115                    return;
116                }
117                Err(_) => {
118                    debug!(%peer, "TLS handshake timed out");
119                    return;
120                }
121            };
122            let io = TokioIo::new(tls_stream);
123            let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
124                tower_service.clone().oneshot(request)
125            });
126            if let Err(e) = ConnBuilder::new(TokioExecutor::new())
127                .serve_connection_with_upgrades(io, hyper_service)
128                .await
129            {
130                debug!(error = %e, %peer, "error serving TLS connection");
131            }
132        });
133    }
134    Ok(())
135}
136
/// Extracts the value from a `Result` whose error type is uninhabited.
///
/// `Infallible` has no values, so the error closure can never run; the empty `match` proves
/// that to the compiler without introducing a panic path.
fn unwrap_infallible<T>(result: Result<T, std::convert::Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Missing files must fail, and the error must name the file that couldn't be opened
    /// (the certificate is loaded first, so that is the one reported).
    #[test]
    fn load_server_config_errors_on_missing_files() {
        let err = load_server_config("/no/such/cert.pem", "/no/such/key.pem").unwrap_err();
        let rendered = format!("{err:#}");
        assert!(
            rendered.contains("opening certificate file /no/such/cert.pem"),
            "unexpected error: {rendered}"
        );
    }

    /// A readable file with no PEM certificate sections is rejected before the key is read.
    #[test]
    fn load_server_config_errors_on_empty_cert_file() {
        let dir = std::env::temp_dir().join(format!("edgeguard-tls-test-{}", std::process::id()));
        std::fs::create_dir_all(&dir).expect("creating temp test dir");
        let cert = dir.join("empty-cert.pem");
        std::fs::write(&cert, b"").expect("writing empty cert file");

        let result = load_server_config(cert.to_str().expect("utf-8 temp path"), "/no/such/key.pem");
        let _ = std::fs::remove_dir_all(&dir);

        let rendered = format!("{:#}", result.unwrap_err());
        assert!(rendered.contains("no certificates found"), "unexpected error: {rendered}");
    }
}