use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::net::SocketAddr;
use std::sync::Arc;
use http_body_util::BodyExt;
use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey;
use rustls::ServerConfig;
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, warn};
use crate::config::TlsConfig;
/// Reads every PEM-encoded certificate contained in the file at `path`.
///
/// # Errors
///
/// Propagates the I/O error raised while opening the file or while the PEM
/// reader yields an item, and returns an `InvalidData` error when the file
/// holds no certificate at all.
fn load_certs(
    path: &str,
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, std::io::Error> {
    let mut pem = BufReader::new(File::open(path)?);

    // Stop at the first item the PEM reader reports as an error.
    let mut chain = Vec::new();
    for item in rustls_pemfile::certs(&mut pem) {
        chain.push(item?);
    }

    if chain.is_empty() {
        let message = format!("No certificates found in {}", path);
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            message,
        ));
    }
    Ok(chain)
}
/// Reads the first private key found in the PEM file at `path`.
///
/// # Errors
///
/// Propagates the I/O error raised while opening or reading the file, and
/// returns an `InvalidData` error when the file contains no private key.
fn load_key(path: &str) -> Result<rustls::pki_types::PrivateKeyDer<'static>, std::io::Error> {
    let mut pem = BufReader::new(File::open(path)?);

    match rustls_pemfile::private_key(&mut pem)? {
        Some(key) => Ok(key),
        None => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("No private key found in {}", path),
        )),
    }
}
/// Certificate resolver that maps hostnames to certified keys and keeps an
/// optional fallback key.
struct SniCertResolver {
/// Certified keys indexed by hostname. Lookups in `resolve` lower-case the
/// SNI name first, so keys are expected to be stored lower-cased.
entries: HashMap<String, Arc<CertifiedKey>>,
/// Key returned by `resolve` when no entry matches; `None` means no fallback.
default: Option<Arc<CertifiedKey>>,
}
/// Debug output shows only the hostnames held in `entries`; the certified
/// keys themselves and the `default` field are not printed.
impl std::fmt::Debug for SniCertResolver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hostnames: Vec<&String> = self.entries.keys().collect();

        let mut repr = f.debug_struct("SniCertResolver");
        repr.field("entries", &hostnames);
        repr.finish()
    }
}
impl SniCertResolver {
fn new(
entries: HashMap<String, Arc<CertifiedKey>>,
default: Option<Arc<CertifiedKey>>,
) -> Self {
Self { entries, default }
}
}
impl ResolvesServerCert for SniCertResolver {
    /// Picks the certified key for a handshake.
    ///
    /// The SNI name from the ClientHello, lower-cased, is looked up in
    /// `entries`. When the client sent no SNI name or the name has no entry,
    /// the `default` key is returned instead (which may be `None`).
    fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        client_hello
            .server_name()
            .and_then(|name| self.entries.get(&name.to_ascii_lowercase()))
            .or(self.default.as_ref())
            .cloned()
    }
}
pub fn build_tls_acceptor(
sites: &[(String, TlsConfig)], default_hostname: Option<&str>,
) -> Result<TlsAcceptor, std::io::Error> {
if sites.is_empty() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Cannot build TLS acceptor: no TLS sites provided",
));
}
let mut entries: HashMap<String, Arc<CertifiedKey>> = HashMap::new();
let mut default: Option<Arc<CertifiedKey>> = None;
for (hostname, tls_cfg) in sites {
let certs = load_certs(&tls_cfg.cert_path)?;
let key = load_key(&tls_cfg.key_path)?;
let signing_key =
rustls::crypto::aws_lc_rs::sign::any_supported_type(&key).map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"Unsupported key type in {} for host {}: {}",
tls_cfg.key_path, hostname, e
),
)
})?;
let certified_key = Arc::new(CertifiedKey::new(certs, signing_key));
if default_hostname.is_some()
&& hostname.eq_ignore_ascii_case(default_hostname.unwrap_or(""))
{
default = Some(Arc::clone(&certified_key));
}
entries.insert(hostname.to_ascii_lowercase(), certified_key);
}
if default.is_none() {
let mut hostnames: Vec<_> = entries.keys().cloned().collect();
hostnames.sort();
if let Some(first) = hostnames.first() {
default = entries.get(first).cloned();
}
}
let resolver = Arc::new(SniCertResolver::new(entries, default));
let config = ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(resolver);
Ok(TlsAcceptor::from(Arc::new(config)))
}
pub async fn listen_tls<F, Fut>(
addr: SocketAddr,
acceptor: TlsAcceptor,
semaphore: Arc<tokio::sync::Semaphore>,
make_service: F,
) -> anyhow::Result<()>
where
F: Fn(hyper::Request<hyper::body::Incoming>, std::net::SocketAddr) -> Fut
+ Clone
+ Send
+ 'static,
Fut: std::future::Future<
Output = Result<
hyper::Response<
http_body_util::combinators::BoxBody<
bytes::Bytes,
Box<dyn std::error::Error + Send + Sync>,
>,
>,
anyhow::Error,
>,
> + Send,
{
let listener = TcpListener::bind(addr).await?;
info!("TLS listener bound on https://{}", addr);
loop {
let (tcp_stream, remote_addr) = listener.accept().await?;
let io = hyper_util::rt::TokioIo::new(tcp_stream);
let acceptor = acceptor.clone();
let semaphore = semaphore.clone();
let make_service = make_service.clone();
let permit = match semaphore.try_acquire_owned() {
Ok(p) => p,
Err(_) => {
warn!(
"TLS concurrency limit exceeded, rejecting connection from {}",
remote_addr
);
continue;
}
};
tokio::task::spawn(async move {
let _permit = permit;
let tls_stream = match acceptor.accept(io.into_inner()).await {
Ok(s) => s,
Err(e) => {
info!("TLS handshake failed from {}: {}", remote_addr, e);
return;
}
};
let io = hyper_util::rt::TokioIo::new(tls_stream);
let make_service = make_service.clone();
let service = hyper::service::service_fn(move |req| {
let make_service = make_service.clone();
make_service(req, remote_addr)
});
let mut builder = hyper::server::conn::http1::Builder::new();
builder.keep_alive(true).pipeline_flush(false);
if let Err(e) = builder.serve_connection(io, service).await {
error!("TLS connection error from {}: {:?}", remote_addr, e);
}
});
}
}
/// Extracts the host part of a `Host` header value, dropping any port.
///
/// A value starting with `[` and containing a `]` is treated as a bracketed
/// IPv6 literal: the text between the brackets is returned without them.
/// Otherwise everything before the first `:` is returned, or the whole value
/// when it contains no `:`.
fn hostname_from_host_header(host: &str) -> &str {
    if let Some(after_bracket) = host.strip_prefix('[') {
        if let Some((literal, _rest)) = after_bracket.split_once(']') {
            return literal;
        }
        // No closing bracket: fall through and treat it like any other value.
    }

    match host.split_once(':') {
        Some((name, _port)) => name,
        None => host,
    }
}
/// Runs a plain-HTTP server on `addr` that answers every request with a
/// `301` redirect to the same host and path over HTTPS.
///
/// The target host comes from the request's `Host` header with any port
/// removed, falling back to `localhost` when the header is missing or cannot
/// be read as a string. IPv6 literals are put back in brackets so the
/// `Location` value is a valid URL. `tls_port` is appended unless it is 443.
///
/// # Errors
///
/// Returns an error when binding the listener fails or when `accept` fails
/// with anything other than a per-connection error (aborted, reset, or
/// refused); those are logged and skipped. Otherwise the function loops
/// forever.
pub async fn listen_http_redirect(addr: SocketAddr, tls_port: u16) -> anyhow::Result<()> {
    /// Escapes the characters that are significant in HTML text and in
    /// quoted attribute values.
    fn escape_html(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#x27;"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    let listener = TcpListener::bind(addr).await?;
    info!(
        "HTTP→HTTPS redirect server listening on http://{} → :{}",
        addr, tls_port
    );

    loop {
        let (stream, remote_addr) = match listener.accept().await {
            Ok(conn) => conn,
            // These kinds describe the connection that was being accepted,
            // not the listening socket, so the listener itself is still fine.
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::ConnectionAborted
                        | std::io::ErrorKind::ConnectionReset
                        | std::io::ErrorKind::ConnectionRefused
                ) =>
            {
                debug!("Failed to accept redirect connection: {}", e);
                continue;
            }
            Err(e) => return Err(e.into()),
        };
        let io = hyper_util::rt::TokioIo::new(stream);

        tokio::task::spawn(async move {
            let service =
                hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
                    let host = req
                        .headers()
                        .get(hyper::header::HOST)
                        .and_then(|h| h.to_str().ok())
                        .unwrap_or("localhost");
                    let hostname = hostname_from_host_header(host);

                    // `hostname_from_host_header` strips the brackets from an
                    // IPv6 literal; a URL authority needs them back. Only
                    // bracketed input can yield a hostname containing ':'.
                    let (open, close) = if hostname.contains(':') {
                        ("[", "]")
                    } else {
                        ("", "")
                    };

                    let uri = req
                        .uri()
                        .path_and_query()
                        .map(|pq| pq.as_str())
                        .unwrap_or("/");
                    let location = if tls_port == 443 {
                        format!("https://{}{}{}{}", open, hostname, close, uri)
                    } else {
                        format!("https://{}{}{}:{}{}", open, hostname, close, tls_port, uri)
                    };

                    // The Host header is client-controlled, so escape the URL
                    // before embedding it in the HTML body.
                    let body = format!(
                        "<!DOCTYPE html><html><head><title>301 Moved</title></head>\
                         <body><h1>Moved Permanently</h1><p>The document has moved \
                         <a href=\"{}\">here</a>.</p></body></html>",
                        escape_html(&location)
                    );
                    let body_len = body.len();
                    let response = hyper::Response::builder()
                        .status(301)
                        .header("Location", &location)
                        .header("Content-Type", "text/html; charset=utf-8")
                        .header("Content-Length", body_len)
                        .body(
                            http_body_util::Full::new(bytes::Bytes::from(body))
                                .map_err(|e| {
                                    Box::new(e) as Box<dyn std::error::Error + Send + Sync>
                                })
                                .boxed(),
                        )
                        .expect("static redirect response build");
                    std::future::ready(Ok::<_, hyper::Error>(response))
                });

            let mut builder = hyper::server::conn::http1::Builder::new();
            builder.keep_alive(false).pipeline_flush(false);
            if let Err(e) = builder.serve_connection(io, service).await {
                debug!("Redirect connection error from {}: {:?}", remote_addr, e);
            }
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty site list must be rejected instead of producing an acceptor.
    #[test]
    fn test_build_tls_acceptor_empty_sites() {
        assert!(build_tls_acceptor(&[], None).is_err());
    }

    /// Host header values with and without a port, including a bracketed
    /// IPv6 literal, reduce to the bare host.
    #[test]
    fn test_hostname_from_host_header_ipv6() {
        let cases = [
            ("[::1]:8080", "::1"),
            ("example.com:8080", "example.com"),
            ("example.com", "example.com"),
        ];
        for &(header, expected) in &cases {
            assert_eq!(hostname_from_host_header(header), expected);
        }
    }

    /// A missing certificate file surfaces as an error.
    #[test]
    fn test_load_certs_nonexistent_file() {
        assert!(load_certs("/nonexistent/path/cert.pem").is_err());
    }

    /// A missing key file surfaces as an error.
    #[test]
    fn test_load_key_nonexistent_file() {
        assert!(load_key("/nonexistent/path/key.pem").is_err());
    }
}