use std::collections::{HashMap, HashSet};
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use rustls::pki_types::{PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer, PrivateSec1KeyDer};
use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use tokio::sync::{Mutex, Notify, RwLock};
/// Policy callback for on-demand TLS; receives the lower-cased server name and
/// returns whether a certificate may be obtained for it.
type DecisionFunc = Arc<dyn Fn(&str) -> bool + Send + Sync>;
/// Heap-allocated, `Send` future yielding `T`, as returned by the async hooks below.
type BoxedSendFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// Async hook invoked with a server name to obtain a certificate for it.
type ObtainFunc = Arc<dyn Fn(String) -> BoxedSendFuture<Result<()>> + Send + Sync>;
/// Async hook invoked with a domain whose cached certificate lacks an OCSP staple.
type OcspRefreshFunc = Arc<dyn Fn(String) -> BoxedSendFuture<()> + Send + Sync>;
use tracing::{debug, warn};
use crate::cache::CertCache;
use crate::certificates::{Certificate, PrivateKeyKind};
use crate::error::{CryptoError, Result};
use crate::rate_limiter::RateLimiter;
/// Policy and hooks for on-demand TLS: obtaining certificates during a
/// handshake for names that are not yet in the cache.
pub struct OnDemandConfig {
    /// Authoritative allow/deny callback, given the lower-cased server name.
    /// When set, `host_allowlist` is ignored.
    pub decision_func: Option<DecisionFunc>,
    /// Names permitted for on-demand obtaining; consulted only when
    /// `decision_func` is `None`. Compared against the lower-cased server
    /// name, so entries should be lower-case. With neither this nor
    /// `decision_func` set, every name is denied.
    pub host_allowlist: Option<HashSet<String>>,
    /// Optional limiter checked before each obtain; a denied request is skipped.
    pub rate_limit: Option<Arc<RateLimiter>>,
    /// Hook that performs the obtain for a name; nothing is obtained when `None`.
    pub obtain_func: Option<ObtainFunc>,
}
impl fmt::Debug for OnDemandConfig {
    // Callbacks and the limiter are opaque, so only their presence is shown
    // (`Some("...")` / `None`); the allowlist is printed in full.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const OPAQUE: &str = "...";
        let decision = self.decision_func.as_ref().map(|_| OPAQUE);
        let limiter = self.rate_limit.as_ref().map(|_| OPAQUE);
        let obtain = self.obtain_func.as_ref().map(|_| OPAQUE);

        let mut out = f.debug_struct("OnDemandConfig");
        out.field("decision_func", &decision);
        out.field("host_allowlist", &self.host_allowlist);
        out.field("rate_limit", &limiter);
        out.field("obtain_func", &obtain);
        out.finish()
    }
}
impl OnDemandConfig {
fn is_allowed(&self, name: &str) -> bool {
let lower = name.to_lowercase();
if let Some(ref func) = self.decision_func {
return func(&lower);
}
if let Some(ref allowlist) = self.host_allowlist {
return allowlist.contains(&lower);
}
debug!(
name = %name,
"on-demand TLS denied: no decision_func or host_allowlist configured",
);
false
}
}
/// rustls certificate resolver backed by a [`CertCache`].
///
/// Besides exact and wildcard cache lookups, it serves TLS-ALPN-01 challenge
/// certificates, supports a default certificate, default and fallback server
/// names, and optional on-demand certificate obtaining.
pub struct CertResolver {
    /// Certificate store, queried synchronously during each handshake.
    cache: Arc<CertCache>,
    /// On-demand TLS policy; `None` disables on-demand obtaining.
    on_demand: Option<Arc<OnDemandConfig>>,
    /// Certificate served when nothing more specific is found.
    default_cert: RwLock<Option<Arc<CertifiedKey>>>,
    /// Name substituted when the ClientHello has no (or an empty) SNI.
    default_server_name: Option<String>,
    /// Name whose cached certificate is tried before `default_cert`.
    fallback_server_name: Option<String>,
    /// TLS-ALPN-01 challenge certificates, keyed by domain.
    acme_challenges: Arc<RwLock<HashMap<String, Arc<CertifiedKey>>>>,
    /// Names with an on-demand obtain in flight, used for de-duplication.
    pending_obtains: Arc<Mutex<HashMap<String, Arc<Notify>>>>,
    /// Optional hook spawned when a cached certificate lacks an OCSP staple.
    pub ocsp_refresh_func: Option<OcspRefreshFunc>,
}
impl fmt::Debug for CertResolver {
    /// Formats the resolver for diagnostics without taking any lock.
    ///
    /// Lock-protected state is shown as a type placeholder. Callbacks and the
    /// on-demand config only report whether they are set (`Some("...")` /
    /// `None`). Every field is listed, so a configured OCSP refresh hook is
    /// visible.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CertResolver")
            .field("cache", &self.cache)
            .field("on_demand", &self.on_demand.as_ref().map(|_| "..."))
            .field("default_cert", &"<RwLock<Option<Arc<CertifiedKey>>>>")
            .field("default_server_name", &self.default_server_name)
            .field("fallback_server_name", &self.fallback_server_name)
            .field("acme_challenges", &"<Arc<RwLock<HashMap>>>")
            .field("pending_obtains", &"<Arc<Mutex<HashMap>>>")
            .field(
                "ocsp_refresh_func",
                &self.ocsp_refresh_func.as_ref().map(|_| "..."),
            )
            .finish()
    }
}
impl CertResolver {
pub fn new(cache: Arc<CertCache>) -> Self {
Self {
cache,
on_demand: None,
default_cert: RwLock::new(None),
default_server_name: None,
fallback_server_name: None,
acme_challenges: Arc::new(RwLock::new(HashMap::new())),
pending_obtains: Arc::new(Mutex::new(HashMap::new())),
ocsp_refresh_func: None,
}
}
pub fn with_on_demand(cache: Arc<CertCache>, on_demand: Arc<OnDemandConfig>) -> Self {
Self {
cache,
on_demand: Some(on_demand),
default_cert: RwLock::new(None),
default_server_name: None,
fallback_server_name: None,
acme_challenges: Arc::new(RwLock::new(HashMap::new())),
pending_obtains: Arc::new(Mutex::new(HashMap::new())),
ocsp_refresh_func: None,
}
}
pub fn set_default_server_name(&mut self, name: Option<String>) {
self.default_server_name = name;
}
pub fn set_fallback_server_name(&mut self, name: Option<String>) {
self.fallback_server_name = name;
}
pub async fn set_challenge_cert(&self, domain: String, cert: Arc<CertifiedKey>) {
let mut challenges = self.acme_challenges.write().await;
challenges.insert(domain, cert);
}
pub async fn remove_challenge_cert(&self, domain: &str) {
let mut challenges = self.acme_challenges.write().await;
challenges.remove(domain);
}
pub async fn set_default_cert(&self, cert: Arc<CertifiedKey>) {
let mut guard = self.default_cert.write().await;
*guard = Some(cert);
}
pub async fn clear_default_cert(&self) {
let mut guard = self.default_cert.write().await;
*guard = None;
}
fn resolve_name(&self, server_name: Option<&str>) -> Option<Arc<CertifiedKey>> {
let name = match server_name {
Some(n) if !n.is_empty() => n.to_owned(),
_ => {
if let Some(ref default_name) = self.default_server_name {
debug!(
default_server_name = %default_name,
"no SNI in ClientHello; using default_server_name",
);
default_name.clone()
} else {
debug!("no SNI in ClientHello; returning default cert");
return self.try_default_cert();
}
}
};
debug!(sni = %name, "resolving certificate for TLS handshake");
if let Some(result) = self.try_resolve_from_cache_with_ocsp_check(&name) {
return Some(result);
}
{
let labels: Vec<&str> = name.split('.').collect();
if labels.len() >= 2 {
for i in 1..labels.len() {
let wildcard = format!("*.{}", labels[i..].join("."));
if let Some(result) = self.try_resolve_from_cache(&wildcard) {
debug!(
sni = %name,
wildcard = %wildcard,
"certificate found via wildcard matching",
);
return Some(result);
}
}
}
}
if let Some(ref on_demand) = self.on_demand {
if on_demand.is_allowed(&name) {
self.trigger_on_demand_obtain(on_demand, &name);
debug!(
sni = %name,
"on-demand certificate obtain triggered; returning default cert for now",
);
} else {
debug!(
sni = %name,
"on-demand TLS not allowed for this name",
);
}
}
if let Some(ref fallback) = self.fallback_server_name
&& fallback != &name
{
debug!(
sni = %name,
fallback = %fallback,
"trying fallback_server_name",
);
if let Some(result) = self.try_resolve_from_cache(fallback) {
return Some(result);
}
}
self.try_default_cert()
}
fn try_resolve_from_cache(&self, name: &str) -> Option<Arc<CertifiedKey>> {
if let Some(cert) = self.try_cache_lookup(name) {
match cert_to_certified_key(&cert) {
Ok(ck) => {
debug!(sni = %name, hash = %cert.hash, "certificate found in cache");
return Some(ck);
}
Err(e) => {
warn!(
sni = %name,
hash = %cert.hash,
error = %e,
"failed to convert cached certificate to CertifiedKey",
);
}
}
}
None
}
fn try_resolve_from_cache_with_ocsp_check(&self, name: &str) -> Option<Arc<CertifiedKey>> {
if let Some(cert) = self.try_cache_lookup(name) {
if (cert.ocsp_status.is_none() || cert.ocsp_response.is_none())
&& let Some(ref refresh_fn) = self.ocsp_refresh_func
{
let refresh = Arc::clone(refresh_fn);
let domain = name.to_owned();
debug!(
sni = %name,
hash = %cert.hash,
"OCSP staple missing; spawning background refresh",
);
tokio::spawn(async move {
refresh(domain).await;
});
}
match cert_to_certified_key(&cert) {
Ok(ck) => {
debug!(sni = %name, hash = %cert.hash, "certificate found in cache");
return Some(ck);
}
Err(e) => {
warn!(
sni = %name,
hash = %cert.hash,
error = %e,
"failed to convert cached certificate to CertifiedKey",
);
}
}
}
None
}
fn try_cache_lookup(&self, name: &str) -> Option<Certificate> {
let handle = tokio::runtime::Handle::try_current().ok()?;
tokio::task::block_in_place(|| handle.block_on(self.cache.get_by_name(name)))
}
fn trigger_on_demand_obtain(&self, on_demand: &Arc<OnDemandConfig>, name: &str) {
let pending = self.pending_obtains.clone();
let name_owned = name.to_owned();
let maybe_guard = self.pending_obtains.try_lock();
match maybe_guard {
Ok(mut guard) => {
if guard.contains_key(&name_owned) {
debug!(
sni = %name,
"on-demand obtain already in progress for this domain; waiting",
);
return;
}
let notify = Arc::new(Notify::new());
guard.insert(name_owned.clone(), Arc::clone(¬ify));
let on_demand = Arc::clone(on_demand);
let name_for_task = name_owned.clone();
tokio::spawn(async move {
Self::do_on_demand_obtain(&on_demand, &name_for_task).await;
notify.notify_waiters();
let mut guard = pending.lock().await;
guard.remove(&name_for_task);
});
}
Err(_) => {
let on_demand = Arc::clone(on_demand);
let name_for_task = name_owned;
tokio::spawn(async move {
Self::do_on_demand_obtain(&on_demand, &name_for_task).await;
});
}
}
}
async fn do_on_demand_obtain(on_demand: &OnDemandConfig, name: &str) {
if let Some(ref limiter) = on_demand.rate_limit
&& !limiter.try_allow().await
{
warn!(
sni = %name,
"on-demand certificate obtain rate-limited; skipping",
);
return;
}
if let Some(ref obtain) = on_demand.obtain_func
&& let Err(e) = obtain(name.to_owned()).await
{
warn!(
sni = %name,
error = %e,
"on-demand certificate obtain failed",
);
}
}
fn try_default_cert(&self) -> Option<Arc<CertifiedKey>> {
match self.default_cert.try_read() {
Ok(guard) => guard.clone(),
Err(_) => {
debug!("could not read default cert (lock contended)");
None
}
}
}
}
impl ResolvesServerCert for CertResolver {
    /// rustls entry point: chooses the certificate for an incoming ClientHello.
    ///
    /// A handshake offering the `acme-tls/1` ALPN protocol (TLS-ALPN-01
    /// validation) is answered only with the challenge certificate registered
    /// for its SNI. It gets `None` when there is no SNI or no such certificate.
    /// All other handshakes are resolved by name from the SNI.
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let wants_acme_challenge = match client_hello.alpn() {
            Some(mut protocols) => protocols.any(|proto| proto == b"acme-tls/1"),
            None => false,
        };
        if !wants_acme_challenge {
            return self.resolve_name(client_hello.server_name());
        }

        let Some(sni) = client_hello.server_name() else {
            debug!("TLS-ALPN-01 challenge without SNI; ignoring");
            return None;
        };
        debug!(sni = %sni, "TLS-ALPN-01 challenge request detected");

        match self.try_challenge_lookup(sni) {
            Some(challenge_cert) => {
                debug!(sni = %sni, "serving TLS-ALPN-01 challenge certificate");
                Some(challenge_cert)
            }
            None => {
                warn!(
                    sni = %sni,
                    "TLS-ALPN-01 challenge requested but no challenge cert registered",
                );
                None
            }
        }
    }
}
impl CertResolver {
    /// Returns the TLS-ALPN-01 challenge certificate registered for `name`.
    ///
    /// Never blocks. If the challenge map's lock cannot be taken immediately,
    /// this logs at debug level and returns `None`.
    fn try_challenge_lookup(&self, name: &str) -> Option<Arc<CertifiedKey>> {
        let Ok(challenges) = self.acme_challenges.try_read() else {
            debug!("could not read acme_challenges (lock contended)");
            return None;
        };
        challenges.get(name).cloned()
    }
}
pub fn signing_key_from_der(
key_der: &PrivateKeyDer<'_>,
) -> std::result::Result<Arc<dyn rustls::sign::SigningKey>, rustls::Error> {
#[cfg(feature = "aws-lc-rs")]
{
rustls::crypto::aws_lc_rs::sign::any_supported_type(key_der)
}
#[cfg(all(feature = "ring", not(feature = "aws-lc-rs")))]
{
rustls::crypto::ring::sign::any_supported_type(key_der)
}
}
/// Converts a cached [`Certificate`] into a rustls [`CertifiedKey`].
///
/// The certificate chain is copied as-is. Any stored OCSP response is attached
/// as the staple; otherwise no staple is set.
///
/// # Errors
///
/// Returns an error built from [`CryptoError::InvalidKey`] when the private
/// key is missing, empty, of kind `None`, or cannot be turned into a signing
/// key.
pub fn cert_to_certified_key(cert: &Certificate) -> Result<Arc<CertifiedKey>> {
    let der = reconstruct_private_key_der(cert)?;
    let key = signing_key_from_der(&der).map_err(|e| {
        CryptoError::InvalidKey(format!(
            "failed to create signing key from private key: {e}"
        ))
    })?;

    let mut certified = CertifiedKey::new(cert.cert_chain.clone(), key);
    // `CertifiedKey::new` starts with no staple; copy ours over when present.
    certified.ocsp = cert.ocsp_response.clone();
    Ok(Arc::new(certified))
}
/// Wraps the certificate's stored private-key bytes in the `PrivateKeyDer`
/// variant named by `private_key_kind`.
///
/// The bytes are not parsed or validated here; they are only tagged with a
/// format.
///
/// # Errors
///
/// Returns an error built from [`CryptoError::InvalidKey`] when there is no
/// key, the key is empty, or the kind is `None`. The checks run in that order.
fn reconstruct_private_key_der(cert: &Certificate) -> Result<PrivateKeyDer<'static>> {
    let Some(bytes) = cert.private_key_der.as_ref() else {
        return Err(CryptoError::InvalidKey("certificate has no private key".into()).into());
    };
    if bytes.is_empty() {
        return Err(CryptoError::InvalidKey("private key bytes are empty".into()).into());
    }

    match cert.private_key_kind {
        PrivateKeyKind::Pkcs8 => Ok(PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(bytes.clone()))),
        PrivateKeyKind::Pkcs1 => Ok(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(bytes.clone()))),
        PrivateKeyKind::Sec1 => Ok(PrivateKeyDer::Sec1(PrivateSec1KeyDer::from(bytes.clone()))),
        PrivateKeyKind::None => Err(CryptoError::InvalidKey(
            "certificate private key kind is None".into(),
        )
        .into()),
    }
}
// Unit tests for on-demand gating (`OnDemandConfig::is_allowed`), private-key
// reconstruction, and the resolver's no-SNI and Debug paths.
#[cfg(test)]
mod tests {
    use chrono::{Duration as ChronoDuration, Utc};
    use super::*;
    use crate::cache::CacheOptions;
    /// Builds a managed test certificate for `names` with the given `hash`.
    /// It has an empty chain, no private key (kind `None`), no OCSP data, and
    /// is valid from one day ago until 90 days from now.
    fn make_test_cert(names: &[&str], hash: &str) -> Certificate {
        let now = Utc::now();
        Certificate {
            cert_chain: Vec::new(),
            private_key_der: None,
            private_key_kind: PrivateKeyKind::None,
            names: names.iter().map(|n| n.to_string()).collect(),
            tags: Vec::new(),
            managed: true,
            issuer_key: String::new(),
            hash: hash.to_string(),
            ocsp_response: None,
            ocsp_status: None,
            not_after: now + ChronoDuration::days(90),
            not_before: now - ChronoDuration::days(1),
            ari: None,
        }
    }
    // With neither a decision_func nor an allowlist, every name is denied.
    #[test]
    fn test_on_demand_config_denied_without_gating() {
        let config = OnDemandConfig {
            decision_func: None,
            host_allowlist: None,
            rate_limit: None,
            obtain_func: None,
        };
        assert!(!config.is_allowed("example.com"));
    }
    // Allowlist hits match regardless of the input name's case, because the
    // name is lower-cased before lookup.
    #[test]
    fn test_on_demand_config_allowlist_hit() {
        let mut allowlist = HashSet::new();
        allowlist.insert("example.com".to_string());
        let config = OnDemandConfig {
            decision_func: None,
            host_allowlist: Some(allowlist),
            rate_limit: None,
            obtain_func: None,
        };
        assert!(config.is_allowed("example.com"));
        assert!(config.is_allowed("Example.COM"));
    }
    // Names absent from the allowlist are denied.
    #[test]
    fn test_on_demand_config_allowlist_miss() {
        let mut allowlist = HashSet::new();
        allowlist.insert("example.com".to_string());
        let config = OnDemandConfig {
            decision_func: None,
            host_allowlist: Some(allowlist),
            rate_limit: None,
            obtain_func: None,
        };
        assert!(!config.is_allowed("other.com"));
    }
    // A decision_func alone decides allow/deny.
    #[test]
    fn test_on_demand_config_decision_func() {
        let config = OnDemandConfig {
            decision_func: Some(Arc::new(|name: &str| name.ends_with(".example.com"))),
            host_allowlist: None,
            rate_limit: None,
            obtain_func: None,
        };
        assert!(config.is_allowed("sub.example.com"));
        assert!(!config.is_allowed("other.com"));
    }
    // When both are configured, the decision_func wins over the allowlist.
    #[test]
    fn test_on_demand_config_decision_func_takes_priority() {
        let mut allowlist = HashSet::new();
        allowlist.insert("blocked.com".to_string());
        let config = OnDemandConfig {
            decision_func: Some(Arc::new(|_| false)),
            host_allowlist: Some(allowlist),
            rate_limit: None,
            obtain_func: None,
        };
        assert!(!config.is_allowed("blocked.com"));
    }
    // Construction smoke test.
    #[test]
    fn test_cert_resolver_new() {
        let cache = CertCache::new(CacheOptions::default());
        let resolver = CertResolver::new(cache);
        assert!(format!("{:?}", resolver).contains("CertResolver"));
    }
    // No SNI, no default_server_name, and no default cert set: nothing to serve.
    #[test]
    fn test_cert_resolver_no_sni_returns_none_without_default() {
        let cache = CertCache::new(CacheOptions::default());
        let resolver = CertResolver::new(cache);
        let result = resolver.resolve_name(None);
        assert!(result.is_none());
    }
    // An empty SNI is treated the same as a missing one.
    #[test]
    fn test_cert_resolver_empty_sni_returns_none_without_default() {
        let cache = CertCache::new(CacheOptions::default());
        let resolver = CertResolver::new(cache);
        let result = resolver.resolve_name(Some(""));
        assert!(result.is_none());
    }
    // Missing private key bytes are an error.
    #[test]
    fn test_reconstruct_private_key_no_key() {
        let cert = make_test_cert(&["example.com"], "h1");
        let result = reconstruct_private_key_der(&cert);
        assert!(result.is_err());
    }
    // Present-but-empty key bytes are an error.
    #[test]
    fn test_reconstruct_private_key_empty_bytes() {
        let mut cert = make_test_cert(&["example.com"], "h1");
        cert.private_key_der = Some(Vec::new());
        cert.private_key_kind = PrivateKeyKind::Pkcs8;
        let result = reconstruct_private_key_der(&cert);
        assert!(result.is_err());
    }
    // Key bytes with kind `None` are an error.
    #[test]
    fn test_reconstruct_private_key_none_kind() {
        let mut cert = make_test_cert(&["example.com"], "h1");
        cert.private_key_der = Some(vec![1, 2, 3]);
        cert.private_key_kind = PrivateKeyKind::None;
        let result = reconstruct_private_key_der(&cert);
        assert!(result.is_err());
    }
    // The following three succeed with dummy bytes: reconstruction only tags
    // the bytes with a format and does not parse them.
    #[test]
    fn test_reconstruct_private_key_pkcs8() {
        let mut cert = make_test_cert(&["example.com"], "h1");
        cert.private_key_der = Some(vec![1, 2, 3, 4]);
        cert.private_key_kind = PrivateKeyKind::Pkcs8;
        let result = reconstruct_private_key_der(&cert);
        assert!(result.is_ok());
    }
    #[test]
    fn test_reconstruct_private_key_pkcs1() {
        let mut cert = make_test_cert(&["example.com"], "h1");
        cert.private_key_der = Some(vec![1, 2, 3, 4]);
        cert.private_key_kind = PrivateKeyKind::Pkcs1;
        let result = reconstruct_private_key_der(&cert);
        assert!(result.is_ok());
    }
    #[test]
    fn test_reconstruct_private_key_sec1() {
        let mut cert = make_test_cert(&["example.com"], "h1");
        cert.private_key_der = Some(vec![1, 2, 3, 4]);
        cert.private_key_kind = PrivateKeyKind::Sec1;
        let result = reconstruct_private_key_der(&cert);
        assert!(result.is_ok());
    }
    // Debug output names the type.
    #[test]
    fn test_on_demand_config_debug() {
        let config = OnDemandConfig {
            decision_func: None,
            host_allowlist: None,
            rate_limit: None,
            obtain_func: None,
        };
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("OnDemandConfig"));
    }
    // Debug output names the type.
    #[test]
    fn test_cert_resolver_debug() {
        let cache = CertCache::new(CacheOptions::default());
        let resolver = CertResolver::new(cache);
        let debug_str = format!("{:?}", resolver);
        assert!(debug_str.contains("CertResolver"));
    }
}