use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use arc_swap::ArcSwap;
use certon::{
AcmeIssuer, CertResolver, Certificate, Config as CertAutoConfig, FileStorage, OnDemandConfig,
Storage,
};
use rustls::RootCertStore;
use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey;
use tokio::task::JoinHandle;
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, warn};
use crate::ProxyError;
use crate::config::{
AcmeConfig, AppConfig, CertAuthority, ChallengeType, ClientAuthConfig, DnsProviderConfig,
OnDemandTlsConfig, SiteTlsConfig, TlsConfig,
};
/// Holds the TLS state built from the application configuration: the
/// certificate resolver, the current rustls server configuration and, when
/// ACME is configured, the certon configuration plus its maintenance task.
pub struct TlsManager {
/// certon configuration; `build` only sets this to `Some` when an ACME
/// section is present in the TLS configuration.
certon_config: Option<Arc<CertAutoConfig>>,
/// Resolver that serves manually configured certificates and otherwise
/// defers to the ACME resolver (if any).
resolver: Arc<CompositeResolver>,
/// Current rustls server configuration; `reload` stores a replacement here.
server_config: ArcSwap<rustls::ServerConfig>,
/// Handle returned by `certon::start_maintenance`; aborted when the manager
/// is dropped.
maintenance_handle: Option<JoinHandle<()>>,
/// Map shared with the HTTP-01 solver, which stores
/// challenge token -> key authorization entries in it.
challenge_map: Arc<tokio::sync::RwLock<HashMap<String, String>>>,
}
impl TlsManager {
pub async fn build(config: &AppConfig) -> Result<Self, ProxyError> {
let challenge_map: Arc<tokio::sync::RwLock<HashMap<String, String>>> =
Arc::new(tokio::sync::RwLock::new(HashMap::new()));
let mut manual_certs: HashMap<String, Arc<CertifiedKey>> = HashMap::new();
let mut acme_domains: Vec<String> = Vec::new();
for site in &config.sites {
if let Some(ref site_tls) = site.tls {
match load_manual_cert(site_tls) {
Ok(certified_key) => {
info!(
host = %site.host,
cert = %site_tls.cert,
"loaded manual TLS certificate"
);
manual_certs.insert(site.host.clone(), certified_key);
}
Err(e) => {
error!(
host = %site.host,
cert = %site_tls.cert,
key = %site_tls.key,
error = %e,
"failed to load manual TLS certificate"
);
return Err(ProxyError::Internal(format!(
"failed to load TLS certificate for {}: {e}",
site.host
)));
}
}
} else if config.tls.is_some() {
acme_domains.push(site.host.clone());
}
}
let (certon_config, cert_resolver, maintenance_handle) =
if let Some(ref tls_config) = config.tls {
if let Some(ref acme_config) = tls_config.acme {
let (ca_config, resolver, handle) = setup_acme(
acme_config,
&acme_domains,
challenge_map.clone(),
tls_config.on_demand.as_ref(),
&config.sites,
)
.await?;
(Some(Arc::new(ca_config)), Some(resolver), Some(handle))
} else {
(None, None, None)
}
} else {
(None, None, None)
};
let composite = Arc::new(CompositeResolver {
manual_certs: tokio::sync::RwLock::new(manual_certs),
acme_resolver: cert_resolver,
});
let client_verifier = if let Some(ref tls_config) = config.tls {
if let Some(ref client_auth) = tls_config.client_auth {
Some(build_client_verifier(client_auth)?)
} else {
None
}
} else {
None
};
let rustls_config = build_server_config(
composite.clone(),
client_verifier.as_ref(),
config.tls.as_ref(),
);
if let Some(ref tls_config) = config.tls
&& tls_config.ocsp_stapling
{
info!(
"OCSP stapling enabled (handled by certon for ACME certs; \
manual certs require AIA extension parsing)"
);
}
Ok(Self {
certon_config,
resolver: composite,
server_config: ArcSwap::from_pointee(rustls_config),
maintenance_handle,
challenge_map,
})
}
/// Returns a `TlsAcceptor` wrapping the server configuration that is current
/// at the time of the call.
pub fn acceptor(&self) -> TlsAcceptor {
    TlsAcceptor::from(self.server_config.load_full())
}
pub fn server_config(&self) -> Arc<rustls::ServerConfig> {
self.server_config.load_full()
}
/// Returns another handle to the shared HTTP-01 challenge map
/// (challenge token -> key authorization).
pub fn challenge_map(&self) -> Arc<tokio::sync::RwLock<HashMap<String, String>>> {
    Arc::clone(&self.challenge_map)
}
/// Re-reads certificates and TLS settings from `config` and installs a new
/// server configuration.
///
/// Everything that can fail hard (manual certificates, the client
/// certificate verifier) is prepared *before* any live state is touched, so
/// a failed reload leaves the previous certificates and server
/// configuration in place.
///
/// A failure to build the client-certificate verifier aborts the reload.
/// Installing a configuration without the verifier would silently turn
/// mutual TLS off.
///
/// Asking certon to manage the ACME domains stays best-effort: a failure is
/// logged as a warning and the reload continues.
///
/// # Errors
///
/// Returns `ProxyError` when a manual certificate cannot be loaded or when
/// the client-certificate verifier cannot be rebuilt.
pub async fn reload(&self, config: &AppConfig) -> Result<(), ProxyError> {
    info!("reloading TLS configuration");

    let mut new_manual: HashMap<String, Arc<CertifiedKey>> = HashMap::new();
    let mut acme_domains: Vec<String> = Vec::new();
    for site in &config.sites {
        let Some(site_tls) = site.tls.as_ref() else {
            if config.tls.is_some() {
                acme_domains.push(site.host.clone());
            }
            continue;
        };
        let certified_key = load_manual_cert(site_tls).map_err(|e| {
            error!(
                host = %site.host,
                error = %e,
                "failed to reload manual TLS certificate"
            );
            ProxyError::Internal(format!(
                "failed to reload TLS certificate for {}: {e}",
                site.host
            ))
        })?;
        info!(
            host = %site.host,
            cert = %site_tls.cert,
            "reloaded manual TLS certificate"
        );
        new_manual.insert(site.host.clone(), certified_key);
    }

    // Fail closed: never fall back to "no client auth" when the verifier
    // cannot be rebuilt.
    let client_verifier = match config
        .tls
        .as_ref()
        .and_then(|tls_config| tls_config.client_auth.as_ref())
    {
        Some(client_auth) => Some(build_client_verifier(client_auth).map_err(|e| {
            error!(error = %e, "failed to rebuild client cert verifier during reload");
            e
        })?),
        None => None,
    };

    // From here on nothing returns an error, so live state may be replaced.
    {
        let mut guard = self.resolver.manual_certs.write().await;
        *guard = new_manual;
    }

    if !acme_domains.is_empty()
        && let Some(ref ca_config) = self.certon_config
        && let Err(e) = ca_config.manage_sync(&acme_domains).await
    {
        warn!(
            error = %e,
            "failed to manage new ACME domains during reload"
        );
    }

    let new_config = build_server_config(
        self.resolver.clone(),
        client_verifier.as_ref(),
        config.tls.as_ref(),
    );
    self.server_config.store(Arc::new(new_config));
    info!("TLS configuration reloaded successfully");
    Ok(())
}
/// Stops the certon cache; does nothing when ACME was never configured.
pub fn stop_maintenance(&self) {
    let Some(ca_config) = self.certon_config.as_ref() else {
        return;
    };
    ca_config.cache.stop();
    info!("certon maintenance loop stopped");
}
}
impl Drop for TlsManager {
fn drop(&mut self) {
if let Some(ref ca_config) = self.certon_config {
ca_config.cache.stop();
}
if let Some(ref handle) = self.maintenance_handle {
handle.abort();
}
}
}
/// Certificate resolver that checks manually configured certificates first
/// and otherwise delegates to the ACME resolver.
struct CompositeResolver {
/// Manually loaded certificates, keyed by the site's configured host.
manual_certs: tokio::sync::RwLock<HashMap<String, Arc<CertifiedKey>>>,
/// Resolver for ACME-managed certificates; `None` when ACME is not set up.
acme_resolver: Option<Arc<CertResolver>>,
}
impl std::fmt::Debug for CompositeResolver {
    /// Prints placeholders only: neither the lock contents nor the ACME
    /// resolver are formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let acme_placeholder: Option<&str> = self.acme_resolver.as_ref().map(|_| "...");
        let mut out = f.debug_struct("CompositeResolver");
        out.field("manual_certs", &"<RwLock<HashMap>>");
        out.field("acme_resolver", &acme_placeholder);
        out.finish()
    }
}
impl ResolvesServerCert for CompositeResolver {
    /// Picks the certificate for a handshake.
    ///
    /// A manual certificate stored under the client's SNI name wins. The map
    /// is read with `try_read`, so when the lock cannot be taken immediately
    /// the manual lookup is skipped. Without a manual hit the ACME resolver
    /// (if any) decides; otherwise no certificate is returned.
    fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let manual_hit = client_hello.server_name().and_then(|sni| {
            let guard = self.manual_certs.try_read().ok()?;
            let ck = guard.get(sni)?;
            debug!(sni = %sni, "serving manual TLS certificate");
            Some(Arc::clone(ck))
        });
        if manual_hit.is_some() {
            return manual_hit;
        }
        self.acme_resolver
            .as_ref()
            .and_then(|resolver| resolver.resolve(client_hello))
    }
}
/// Builds the client-certificate verifier for mutual TLS.
///
/// Every file listed in `client_auth.ca_certs` is read and parsed as PEM, and
/// all certificates found are added to one root store. When
/// `client_auth.required` is false the verifier also admits unauthenticated
/// clients.
///
/// # Errors
///
/// Returns `ProxyError::Internal` when a CA file cannot be read or parsed,
/// contains no certificate, holds a certificate the root store rejects, or
/// when the verifier itself cannot be built.
fn build_client_verifier(
    client_auth: &ClientAuthConfig,
) -> Result<Arc<dyn rustls::server::danger::ClientCertVerifier>, ProxyError> {
    let mut trust_anchors = RootCertStore::empty();

    for ca_path in &client_auth.ca_certs {
        let pem_bytes = match fs::read(ca_path) {
            Ok(bytes) => bytes,
            Err(e) => {
                return Err(ProxyError::Internal(format!(
                    "failed to read CA cert file {ca_path}: {e}"
                )));
            }
        };

        // Parse the whole file first so that a broken file adds nothing.
        let mut pem_reader = std::io::BufReader::new(pem_bytes.as_slice());
        let mut parsed = Vec::new();
        for item in rustls_pemfile::certs(&mut pem_reader) {
            let cert = item.map_err(|e| {
                ProxyError::Internal(format!(
                    "failed to parse PEM certificates from {ca_path}: {e}"
                ))
            })?;
            parsed.push(cert);
        }
        if parsed.is_empty() {
            return Err(ProxyError::Internal(format!(
                "no certificates found in CA file: {ca_path}"
            )));
        }

        parsed.into_iter().try_for_each(|cert| {
            trust_anchors.add(cert).map_err(|e| {
                ProxyError::Internal(format!("failed to add CA cert from {ca_path}: {e}"))
            })
        })?;
        info!(path = %ca_path, "loaded CA certificate for client auth");
    }

    let builder = rustls::server::WebPkiClientVerifier::builder(Arc::new(trust_anchors));
    if !client_auth.required {
        info!("mTLS: client certificates optional");
        return builder.allow_unauthenticated().build().map_err(|e| {
            ProxyError::Internal(format!(
                "failed to build optional client cert verifier: {e}"
            ))
        });
    }

    info!("mTLS: client certificates required");
    builder
        .build()
        .map_err(|e| ProxyError::Internal(format!("failed to build client cert verifier: {e}")))
}
/// Sets up certon for ACME-managed certificates.
///
/// Builds an issuer for the configured CA and challenge type (with optional
/// external account binding), builds the certon configuration (with optional
/// on-demand TLS), obtains certificates for `domains` when the list is not
/// empty, and starts the certon maintenance task.
///
/// Returns the certon configuration, a resolver backed by its cache, and the
/// maintenance task handle.
///
/// # Errors
///
/// Returns `ProxyError` when DNS-01 is selected without a DNS provider, when
/// the DNS provider, EAB key or on-demand configuration cannot be built, or
/// when managing `domains` fails.
async fn setup_acme(
    acme_config: &AcmeConfig,
    domains: &[String],
    challenge_map: Arc<tokio::sync::RwLock<HashMap<String, String>>>,
    on_demand_config: Option<&OnDemandTlsConfig>,
    sites: &[crate::config::SiteConfig],
) -> Result<(CertAutoConfig, Arc<CertResolver>, JoinHandle<()>), ProxyError> {
    let storage: Arc<dyn Storage> = Arc::new(FileStorage::default());

    let ca_url = match acme_config.ca {
        CertAuthority::LetsEncrypt => certon::LETS_ENCRYPT_PRODUCTION,
        CertAuthority::LetsEncryptStaging => certon::LETS_ENCRYPT_STAGING,
        CertAuthority::ZeroSsl => certon::ZEROSSL_PRODUCTION,
    };

    let base_builder = AcmeIssuer::builder()
        .ca(ca_url)
        .email(&acme_config.email)
        .agreed(true)
        .storage(storage.clone());

    // Enable exactly the configured challenge type and switch the others off.
    let challenge_builder = match acme_config.challenge {
        ChallengeType::Http01 => {
            let solver = Arc::new(SharedMapHttp01Solver {
                challenges: challenge_map,
            });
            base_builder
                .http01_solver(solver)
                .disable_tlsalpn_challenge(true)
        }
        ChallengeType::TlsAlpn01 => base_builder.disable_http_challenge(true),
        ChallengeType::Dns01 => {
            let Some(dns_cfg) = acme_config.dns_provider.as_ref() else {
                return Err(ProxyError::Internal(
                    "DNS-01 challenge requires a dns-provider configuration".into(),
                ));
            };
            let provider = create_dns_provider(dns_cfg)?;
            let dns_solver = certon::Dns01Solver::new(provider);
            base_builder
                .dns01_solver(Arc::new(dns_solver))
                .disable_http_challenge(true)
                .disable_tlsalpn_challenge(true)
        }
    };

    let issuer_builder = match acme_config.eab.as_ref() {
        Some(eab_cfg) => {
            let hmac_key = base64_decode_hmac(&eab_cfg.hmac_key)?;
            challenge_builder.external_account(certon::acme_client::ExternalAccountBinding {
                kid: eab_cfg.kid.clone(),
                hmac_key,
            })
        }
        None => challenge_builder,
    };
    let issuer = Arc::new(issuer_builder.build());

    let base_config = CertAutoConfig::builder()
        .storage(storage)
        .issuers(vec![issuer.clone()]);
    let config_builder = match on_demand_config {
        Some(od_config) => {
            let on_demand = build_on_demand_config(od_config, &issuer, sites)?;
            let with_on_demand = base_config.on_demand(Arc::new(on_demand));
            info!("on-demand TLS configured");
            with_on_demand
        }
        None => base_config,
    };
    let ca_config = config_builder.build();

    if !domains.is_empty() {
        info!(domains = ?domains, "managing ACME certificates");
        ca_config.manage_sync(domains).await.map_err(|e| {
            ProxyError::Internal(format!("ACME certificate management failed: {e}"))
        })?;
    }

    let resolver = match ca_config.on_demand.clone() {
        Some(on_demand) => Arc::new(CertResolver::with_on_demand(
            ca_config.cache.clone(),
            on_demand,
        )),
        None => Arc::new(CertResolver::new(ca_config.cache.clone())),
    };

    let maintenance_handle = certon::start_maintenance(&ca_config);
    info!("certon maintenance loop started");
    Ok((ca_config, resolver, maintenance_handle))
}
fn build_on_demand_config(
od_config: &OnDemandTlsConfig,
issuer: &Arc<certon::AcmeIssuer>,
sites: &[crate::config::SiteConfig],
) -> Result<OnDemandConfig, ProxyError> {
let allowlist: HashSet<String> = sites.iter().map(|s| s.host.to_lowercase()).collect();
type DecisionFn = dyn Fn(&str) -> bool + Send + Sync;
let decision_func: Option<Arc<DecisionFn>> = if let Some(ref ask_url) = od_config.ask {
let url = ask_url.clone();
Some(Arc::new(move |domain: &str| {
check_ask_url_blocking(&url, domain)
}))
} else {
None
};
let rate_limit = od_config.rate_limit.map(|max_per_minute| {
Arc::new(certon::rate_limiter::RateLimiter::new(
max_per_minute as usize,
Duration::from_secs(60),
))
});
let issuer_for_obtain = Arc::clone(issuer);
type ObtainFn = dyn Fn(String) -> Pin<Box<dyn std::future::Future<Output = certon::Result<()>> + Send>>
+ Send
+ Sync;
let obtain_func: Arc<ObtainFn> = Arc::new(move |domain: String| {
let issuer = Arc::clone(&issuer_for_obtain);
Box::pin(async move {
info!(domain = %domain, "on-demand TLS: obtaining certificate");
match issuer
.issue_for_domains(std::slice::from_ref(&domain))
.await
{
Ok(_cert) => {
info!(domain = %domain, "on-demand TLS: certificate obtained");
Ok(())
}
Err(e) => {
error!(domain = %domain, error = %e, "on-demand TLS: failed to obtain certificate");
Err(e)
}
}
})
});
Ok(OnDemandConfig {
decision_func,
host_allowlist: if allowlist.is_empty() {
None
} else {
Some(allowlist)
},
rate_limit,
obtain_func: Some(obtain_func),
})
}
/// Asks the on-demand "ask" endpoint whether a certificate may be issued for
/// `domain`.
///
/// Sends `GET <ask_url>` with a `domain=<domain>` query parameter (5 second
/// timeout) and returns `true` only for a successful HTTP status. The
/// parameter is appended with `&` when `ask_url` already has a query string,
/// so existing parameters are preserved instead of being corrupted by a
/// second `?`.
///
/// The request is driven synchronously on the current tokio runtime. Any
/// panic raised while doing so is caught and treated as "not allowed", as is
/// every request error.
///
/// NOTE(review): `domain` is inserted without percent-encoding; this assumes
/// callers only pass syntactically valid host names — confirm.
fn check_ask_url_blocking(ask_url: &str, domain: &str) -> bool {
    // Pick the separator that keeps the resulting query string well formed.
    let separator = if ask_url.ends_with('?') || ask_url.ends_with('&') {
        ""
    } else if ask_url.contains('?') {
        "&"
    } else {
        "?"
    };
    let url = format!("{ask_url}{separator}domain={domain}");
    debug!(url = %url, "on-demand TLS: checking ask URL");

    let result = std::panic::catch_unwind(|| {
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(async {
                let client = reqwest::Client::builder()
                    .timeout(Duration::from_secs(5))
                    .build()
                    .map_err(|e| format!("client build error: {e}"))?;
                let resp = client
                    .get(&url)
                    .send()
                    .await
                    .map_err(|e| format!("request error: {e}"))?;
                Ok::<bool, String>(resp.status().is_success())
            })
        })
    });

    match result {
        Ok(Ok(allowed)) => {
            debug!(url = %url, allowed = %allowed, "on-demand TLS: ask URL response");
            allowed
        }
        Ok(Err(e)) => {
            warn!(url = %url, error = %e, "on-demand TLS: ask URL request failed");
            false
        }
        Err(_) => {
            warn!(url = %url, "on-demand TLS: ask URL check failed (no runtime)");
            false
        }
    }
}
/// HTTP-01 solver that publishes challenges through a shared in-memory map.
struct SharedMapHttp01Solver {
/// Challenge token -> key authorization; entries are inserted by `present`
/// and removed by `cleanup`.
challenges: Arc<tokio::sync::RwLock<HashMap<String, String>>>,
}
#[async_trait::async_trait]
impl certon::Solver for SharedMapHttp01Solver {
    /// Stores `key_auth` under `token` in the shared challenge map.
    async fn present(&self, _domain: &str, token: &str, key_auth: &str) -> certon::Result<()> {
        debug!(token = %token, "presenting HTTP-01 challenge token");
        self.challenges
            .write()
            .await
            .insert(token.to_owned(), key_auth.to_owned());
        Ok(())
    }

    /// Removes the entry for `token` from the shared challenge map.
    async fn cleanup(&self, _domain: &str, token: &str, _key_auth: &str) -> certon::Result<()> {
        debug!(token = %token, "cleaning up HTTP-01 challenge token");
        self.challenges.write().await.remove(token);
        Ok(())
    }
}
/// Loads a site's certificate and key from the PEM files named in `site_tls`
/// and converts them into a rustls `CertifiedKey`.
///
/// # Errors
///
/// Returns `ProxyError::Internal` when the PEM files cannot be loaded or the
/// conversion to a `CertifiedKey` fails.
fn load_manual_cert(site_tls: &SiteTlsConfig) -> Result<Arc<CertifiedKey>, ProxyError> {
    let parsed =
        match Certificate::from_pem_files(Path::new(&site_tls.cert), Path::new(&site_tls.key)) {
            Ok(cert) => cert,
            Err(e) => {
                return Err(ProxyError::Internal(format!(
                    "failed to parse PEM certificate/key ({}, {}): {e}",
                    site_tls.cert, site_tls.key
                )));
            }
        };

    certon::handshake::cert_to_certified_key(&parsed).map_err(|e| {
        ProxyError::Internal(format!(
            "failed to convert certificate to CertifiedKey: {e}"
        ))
    })
}
fn resolve_protocol_versions(
min_version: Option<&str>,
max_version: Option<&str>,
) -> Option<Vec<&'static rustls::SupportedProtocolVersion>> {
if min_version.is_none() && max_version.is_none() {
return None;
}
let all: [&'static rustls::SupportedProtocolVersion; 2] =
[&rustls::version::TLS12, &rustls::version::TLS13];
let version_index = |v: &str| match v {
"1.2" => Some(0usize),
"1.3" => Some(1usize),
_ => None,
};
let min_idx: usize = min_version.and_then(version_index).unwrap_or(0);
let max_idx: usize = max_version.and_then(version_index).unwrap_or(all.len() - 1);
let versions: Vec<&'static rustls::SupportedProtocolVersion> = all[min_idx..=max_idx].to_vec();
if versions.is_empty() {
None
} else {
Some(versions)
}
}
/// Builds a ring-based crypto provider restricted to the cipher suites named
/// in `suite_names`.
///
/// A suite is selected when its `Debug` name contains one of the given names
/// as a substring. When nothing matches, a warning is logged and the default
/// provider's cipher suites are kept.
fn build_provider_with_suites(suite_names: &[String]) -> rustls::crypto::CryptoProvider {
    use rustls::crypto::ring::cipher_suite;

    let candidates: &[rustls::SupportedCipherSuite] = &[
        cipher_suite::TLS13_AES_256_GCM_SHA384,
        cipher_suite::TLS13_AES_128_GCM_SHA256,
        cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
        cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
        cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
        cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
        cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
        cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
        cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
    ];

    // Substring match; an exact match is simply the special case where the
    // configured name equals the whole suite name.
    let is_requested = |suite: &rustls::SupportedCipherSuite| {
        let label = format!("{:?}", suite.suite());
        suite_names
            .iter()
            .any(|wanted| label.contains(wanted.as_str()))
    };

    let mut selected = Vec::new();
    for suite in candidates {
        if is_requested(suite) {
            selected.push(*suite);
        }
    }

    let mut provider = rustls::crypto::ring::default_provider();
    if selected.is_empty() {
        warn!(
            "no matching cipher suites found for {:?}, using defaults",
            suite_names
        );
    } else {
        provider.cipher_suites = selected;
    }
    provider
}
/// Maps configured curve names (case-insensitive: `x25519`, `secp256r1`,
/// `secp384r1`) to ring key-exchange groups, in the order given.
///
/// Unknown names are skipped. When nothing matches, a warning is logged and
/// the default provider's groups are returned.
fn filter_kx_groups(curve_names: &[String]) -> Vec<&'static dyn rustls::crypto::SupportedKxGroup> {
    use rustls::crypto::ring::kx_group;

    let known: &[(&str, &'static dyn rustls::crypto::SupportedKxGroup)] = &[
        ("x25519", kx_group::X25519),
        ("secp256r1", kx_group::SECP256R1),
        ("secp384r1", kx_group::SECP384R1),
    ];

    let selected: Vec<&'static dyn rustls::crypto::SupportedKxGroup> = curve_names
        .iter()
        .filter_map(|name| {
            let wanted = name.to_ascii_lowercase();
            known
                .iter()
                .find(|entry| entry.0 == wanted)
                .map(|entry| entry.1)
        })
        .collect();

    if !selected.is_empty() {
        return selected;
    }

    warn!(
        "no matching ECDH curves for {:?}, using defaults",
        curve_names
    );
    rustls::crypto::ring::default_provider().kx_groups
}
/// Builds the rustls server configuration.
///
/// Starts from the ring default provider, optionally restricted to the
/// configured cipher suites and ECDH curves. Protocol versions come from the
/// configured min/max when `resolve_protocol_versions` yields a list,
/// otherwise rustls' safe defaults are used. Client authentication is enabled
/// when `client_verifier` is given. Certificates are always chosen by
/// `resolver`.
///
/// # Panics
///
/// Panics when rustls rejects the chosen protocol versions for the provider
/// (the `expect` calls below).
fn build_server_config(
    resolver: Arc<dyn ResolvesServerCert>,
    client_verifier: Option<&Arc<dyn rustls::server::danger::ClientCertVerifier>>,
    tls_config: Option<&TlsConfig>,
) -> rustls::ServerConfig {
    let mut provider = rustls::crypto::ring::default_provider();
    if let Some(cfg) = tls_config {
        if !cfg.cipher_suites.is_empty() {
            provider = build_provider_with_suites(&cfg.cipher_suites);
        }
        if !cfg.ecdh_curves.is_empty() {
            provider.kx_groups = filter_kx_groups(&cfg.ecdh_curves);
        }
    }
    let provider = Arc::new(provider);

    let versions = tls_config.and_then(|cfg| {
        resolve_protocol_versions(cfg.min_version.as_deref(), cfg.max_version.as_deref())
    });

    let unversioned = rustls::ServerConfig::builder_with_provider(provider);
    let versioned = match versions.as_deref() {
        Some(list) => unversioned
            .with_protocol_versions(list)
            .expect("TLS protocol versions are valid"),
        None => unversioned
            .with_safe_default_protocol_versions()
            .expect("default protocol versions are valid"),
    };

    let with_client_policy = match client_verifier {
        Some(verifier) => versioned.with_client_cert_verifier(Arc::clone(verifier)),
        None => versioned.with_no_client_auth(),
    };
    with_client_policy.with_cert_resolver(resolver)
}
/// Decodes an EAB HMAC key given in either standard or URL-safe base64.
///
/// The URL-safe characters `-` and `_` are mapped to `+` and `/` before the
/// trimmed input is handed to `decode_base64`.
///
/// # Errors
///
/// Returns `ProxyError::Internal` when the input is not valid base64.
fn base64_decode_hmac(input: &str) -> Result<Vec<u8>, ProxyError> {
    let standard_alphabet = input.trim().replace('-', "+").replace('_', "/");
    match decode_base64(&standard_alphabet) {
        Some(bytes) => Ok(bytes),
        None => Err(ProxyError::Internal(
            "invalid base64 in EAB HMAC key".into(),
        )),
    }
}
/// Decodes standard-alphabet base64 (`A-Z a-z 0-9 + /`).
///
/// Leading and trailing whitespace is trimmed; spaces, `\r` and `\n` inside
/// the data are skipped. `=` padding is optional. An empty input decodes to
/// an empty vector.
///
/// Returns `None` when the input
/// * contains a character outside the alphabet,
/// * has anything other than `=` or whitespace after the first `=`, or
/// * has a symbol count that no encoding can produce (one symbol left over
///   after the complete groups of four, which carries only 6 bits).
fn decode_base64(input: &str) -> Option<Vec<u8>> {
    /// Maps one alphabet byte to its 6-bit value.
    fn sextet(b: u8) -> Option<u32> {
        match b {
            b'A'..=b'Z' => Some(u32::from(b - b'A')),
            b'a'..=b'z' => Some(u32::from(b - b'a') + 26),
            b'0'..=b'9' => Some(u32::from(b - b'0') + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    let is_skippable = |b: u8| matches!(b, b'\n' | b'\r' | b' ');

    let input = input.trim();
    let mut output = Vec::with_capacity(input.len() * 3 / 4);
    let mut buf: u32 = 0; // pending bits, most significant first
    let mut bits: u32 = 0; // number of valid bits in `buf` (always < 8 between symbols)
    let mut symbols: usize = 0; // data symbols seen, padding excluded

    let mut bytes = input.bytes();
    for b in bytes.by_ref() {
        if b == b'=' {
            break;
        }
        if is_skippable(b) {
            continue;
        }
        buf = (buf << 6) | sextet(b)?;
        bits += 6;
        symbols += 1;
        if bits >= 8 {
            bits -= 8;
            output.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }

    // Whatever follows the first '=' must be padding or whitespace.
    if bytes.any(|b| b != b'=' && !is_skippable(b)) {
        return None;
    }
    // A single leftover symbol cannot encode a byte: the input is truncated.
    if symbols % 4 == 1 {
        return None;
    }
    Some(output)
}
fn create_dns_provider(
cfg: &DnsProviderConfig,
) -> Result<Box<dyn certon::DnsProvider>, ProxyError> {
match cfg.provider.as_str() {
"cloudflare" => Ok(Box::new(crate::tls::dns::CloudflareDns::new(cfg)?)),
"route53" => Ok(Box::new(crate::tls::dns::Route53Dns::new(cfg)?)),
"digitalocean" => Ok(Box::new(crate::tls::dns::DigitalOceanDns::new(cfg)?)),
"dnsimple" => Ok(Box::new(crate::tls::dns::DnSimpleDns::new(cfg)?)),
"porkbun" => Ok(Box::new(crate::tls::dns::PorkbunDns::new(cfg)?)),
"ovh" => Ok(Box::new(crate::tls::dns::OvhDns::new(cfg)?)),
"desec" => Ok(Box::new(crate::tls::dns::DesecDns::new(cfg)?)),
"bunny" => Ok(Box::new(crate::tls::dns::BunnyDns::new(cfg)?)),
"rfc2136" => Ok(Box::new(crate::tls::dns::Rfc2136Dns::new(cfg)?)),
other => Err(ProxyError::Internal(format!(
"unknown DNS provider: {other}"
))),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolver that never yields a certificate; the tests only build a
    /// configuration and never run a handshake.
    #[derive(Debug)]
    struct NullResolver;

    impl ResolvesServerCert for NullResolver {
        fn resolve(
            &self,
            _client_hello: rustls::server::ClientHello<'_>,
        ) -> Option<Arc<CertifiedKey>> {
            None
        }
    }

    #[test]
    fn build_server_config_creates_valid_config_no_client_auth() {
        let config = build_server_config(Arc::new(NullResolver), None, None);
        assert!(config.alpn_protocols.is_empty());
    }

    #[test]
    fn build_server_config_with_client_verifier() {
        // Pass a real verifier so the client-verifier branch of
        // `build_server_config` is the one being exercised.
        let verifier = rustls::server::WebPkiClientVerifier::no_client_auth();
        let config = build_server_config(Arc::new(NullResolver), Some(&verifier), None);
        assert!(config.alpn_protocols.is_empty());
    }
}