use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};
use crate::async_jobs::JobQueue;
use crate::cache::CertCache;
use crate::certificates::{Certificate, DEFAULT_RENEWAL_WINDOW_RATIO};
use crate::error::Result;
use crate::ocsp::{self, OcspConfig};
use crate::storage::Storage;
/// Renewal-check period used by [`start_maintenance`] when the configured
/// `renew_check_interval` is zero (ten minutes).
pub const DEFAULT_RENEW_CHECK_INTERVAL: Duration = Duration::from_secs(600);
/// OCSP-check period used by [`start_maintenance`] when the configured
/// `ocsp_check_interval` is zero (one hour).
pub const DEFAULT_OCSP_CHECK_INTERVAL: Duration = Duration::from_secs(3_600);
/// Settings for the background maintenance task started by [`start_maintenance`].
#[derive(Clone)]
pub struct MaintenanceConfig {
    /// Period between renewal checks; zero selects [`DEFAULT_RENEW_CHECK_INTERVAL`].
    pub renew_check_interval: Duration,
    /// Period between OCSP staple checks; zero selects [`DEFAULT_OCSP_CHECK_INTERVAL`].
    pub ocsp_check_interval: Duration,
    /// OCSP settings used when refreshing staples (e.g. whether stapling is
    /// disabled and whether revoked certificates are replaced).
    pub ocsp: OcspConfig,
    /// Backing storage, consulted when reloading certificates renewed by another
    /// instance and when fetching OCSP staples.
    pub storage: Arc<dyn Storage>,
}
/// Hand-written `Debug` that prints every field except `storage` (a trait
/// object that is not formatted here). `finish_non_exhaustive` appends `..`
/// so readers can see that a field was intentionally left out.
impl std::fmt::Debug for MaintenanceConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MaintenanceConfig")
            .field("renew_check_interval", &self.renew_check_interval)
            .field("ocsp_check_interval", &self.ocsp_check_interval)
            .field("ocsp", &self.ocsp)
            // `storage` is omitted on purpose; mark the output as partial.
            .finish_non_exhaustive()
    }
}
/// Callback that renews the certificate for a single domain name.
///
/// The maintenance loop invokes it with a certificate's first name, both for
/// routine renewals and for force-renewal of certificates whose OCSP status is
/// revoked. It returns a boxed `Send` future whose `Result` is only logged.
pub type RenewFn = dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync;
/// Spawns the background certificate-maintenance task.
///
/// Zero intervals in `config` are replaced by [`DEFAULT_RENEW_CHECK_INTERVAL`]
/// and [`DEFAULT_OCSP_CHECK_INTERVAL`]. The task periodically checks managed
/// certificates for renewal and refreshes OCSP staples until the cache's stop
/// signal fires, restarting its inner loop if that loop panics. The returned
/// handle completes once the task has exited.
pub fn start_maintenance(
    cache: Arc<CertCache>,
    config: MaintenanceConfig,
    renew_func: Arc<RenewFn>,
) -> JoinHandle<()> {
    let (renew_interval, ocsp_interval) = (
        normalize_interval(config.renew_check_interval, DEFAULT_RENEW_CHECK_INTERVAL),
        normalize_interval(config.ocsp_check_interval, DEFAULT_OCSP_CHECK_INTERVAL),
    );
    info!(
        renew_check_secs = renew_interval.as_secs(),
        ocsp_check_secs = ocsp_interval.as_secs(),
        "starting certificate maintenance loop",
    );
    // The supervisor future already owns everything it needs; spawn it directly.
    tokio::spawn(maintenance_loop_with_recovery(
        cache,
        config,
        renew_func,
        renew_interval,
        ocsp_interval,
    ))
}
async fn maintenance_loop_with_recovery(
cache: Arc<CertCache>,
config: MaintenanceConfig,
renew_func: Arc<RenewFn>,
renew_interval: Duration,
ocsp_interval: Duration,
) {
loop {
let cache_clone = Arc::clone(&cache);
let config_clone = config.clone();
let renew_func_clone = Arc::clone(&renew_func);
let result = tokio::task::spawn(maintenance_loop(
cache_clone,
config_clone,
renew_func_clone,
renew_interval,
ocsp_interval,
))
.await;
match result {
Ok(()) => {
info!("maintenance: loop exited normally");
break;
}
Err(join_error) => {
if join_error.is_panic() {
error!(
"maintenance: loop panicked, restarting after brief delay: {}",
join_error,
);
tokio::time::sleep(Duration::from_secs(2)).await;
let mut stop_rx = cache.subscribe_stop();
if *stop_rx.borrow_and_update() {
info!("maintenance: stop signal received during panic recovery, exiting");
break;
}
} else {
error!("maintenance: loop task cancelled: {}", join_error,);
break;
}
}
}
}
cache.signal_done();
}
async fn maintenance_loop(
cache: Arc<CertCache>,
config: MaintenanceConfig,
renew_func: Arc<RenewFn>,
renew_interval: Duration,
ocsp_interval: Duration,
) {
let mut stop_rx = cache.subscribe_stop();
let job_queue = JobQueue::new("renewal");
let mut renew_ticker = tokio::time::interval(renew_interval);
let mut ocsp_ticker = tokio::time::interval(ocsp_interval);
renew_ticker.tick().await;
ocsp_ticker.tick().await;
loop {
tokio::select! {
_ = renew_ticker.tick() => {
debug!("maintenance: running renewal check");
check_renewals(
&cache,
DEFAULT_RENEWAL_WINDOW_RATIO,
renew_func.as_ref(),
config.storage.as_ref(),
&job_queue,
).await;
}
_ = ocsp_ticker.tick() => {
debug!("maintenance: running OCSP staple check");
update_ocsp_staples(
&cache,
config.storage.as_ref(),
&config.ocsp,
renew_func.as_ref(),
).await;
}
result = stop_rx.changed() => {
match result {
Ok(()) => {
if *stop_rx.borrow() {
info!("maintenance: received stop signal, exiting");
break;
}
}
Err(_) => {
info!("maintenance: cache dropped, exiting");
break;
}
}
}
}
}
}
/// Submits a renewal job for every managed certificate inside its renewal window.
///
/// The following certificates are skipped:
/// - those that do not need renewal under `renewal_ratio`;
/// - those that have no names;
/// - those that another instance has already renewed (see [`try_reload_from_storage`]).
///
/// For each remaining certificate, `renew_func` is called with its first name.
/// The resulting future runs on `job_queue` under the job name
/// `renew_<domain>`, and its outcome is only logged.
async fn check_renewals(
    cache: &CertCache,
    renewal_ratio: f64,
    renew_func: &RenewFn,
    storage: &dyn Storage,
    job_queue: &JobQueue,
) {
    let managed_certs = cache.get_managed_certificates().await;
    if managed_certs.is_empty() {
        debug!("maintenance: no managed certificates to check for renewal");
        return;
    }
    debug!(
        count = managed_certs.len(),
        "maintenance: checking managed certificates for renewal",
    );
    let due_for_renewal = managed_certs
        .iter()
        .filter(|cert| cert.needs_renewal(renewal_ratio));
    for cert in due_for_renewal {
        let domain = if let Some(name) = cert.names.first() {
            name.clone()
        } else {
            warn!(
                hash = %cert.hash,
                "maintenance: managed certificate has no names; skipping renewal"
            );
            continue;
        };
        if try_reload_from_storage(cache, storage, cert, &domain, renewal_ratio).await {
            // Another instance already renewed it; the cache now holds that copy.
            continue;
        }
        info!(
            domain = %domain,
            hash = %cert.hash,
            expired = cert.expired(),
            "maintenance: certificate needs renewal",
        );
        let job_name = format!("renew_{}", domain);
        // `renew_func` is only borrowed here, so create the future now and let
        // the `move` job closure own it (together with the domain string).
        let renewal = renew_func(domain.clone());
        job_queue
            .submit(job_name, move || async move {
                match renewal.await {
                    Err(e) => {
                        error!(
                            domain = %domain,
                            error = %e,
                            "maintenance: certificate renewal failed",
                        );
                    }
                    Ok(()) => {
                        info!(
                            domain = %domain,
                            "maintenance: certificate renewed successfully",
                        );
                    }
                }
            })
            .await;
    }
}
/// Adopts `domain`'s certificate from storage if another instance already renewed it.
///
/// Loads the certificate and private key stored for `domain` under
/// `cached_cert`'s issuer. The stored copy is adopted when all three hold:
/// - it parses;
/// - its hash differs from the cached copy;
/// - it does not itself need renewal under `renewal_ratio`.
///
/// On adoption the cached entry is replaced and `true` is returned, so the
/// caller can skip renewing. Any storage or parse failure yields `false`.
async fn try_reload_from_storage(
    cache: &CertCache,
    storage: &dyn Storage,
    cached_cert: &Certificate,
    domain: &str,
    renewal_ratio: f64,
) -> bool {
    let stored_cert = match load_stored_certificate(storage, cached_cert, domain).await {
        Some(cert) => cert,
        None => return false,
    };
    let renewed_elsewhere =
        stored_cert.hash != cached_cert.hash && !stored_cert.needs_renewal(renewal_ratio);
    if !renewed_elsewhere {
        return false;
    }
    info!(
        domain = %domain,
        old_hash = %cached_cert.hash,
        new_hash = %stored_cert.hash,
        "maintenance: certificate already renewed by another instance; reloading from storage",
    );
    cache.replace(&cached_cert.hash, stored_cert).await;
    true
}

/// Loads and parses the stored certificate/key pair for `domain`.
///
/// The result is marked as managed under `cached_cert`'s issuer key. Returns
/// `None` on any storage or parse error; this is best-effort by design.
async fn load_stored_certificate(
    storage: &dyn Storage,
    cached_cert: &Certificate,
    domain: &str,
) -> Option<Certificate> {
    let cert_key = crate::storage::site_cert_key(&cached_cert.issuer_key, domain);
    let cert_pem = storage.load(&cert_key).await.ok()?;
    let key_key = crate::storage::site_private_key(&cached_cert.issuer_key, domain);
    let key_pem = storage.load(&key_key).await.ok()?;
    let mut stored = Certificate::from_pem(&cert_pem, &key_pem).ok()?;
    stored.managed = true;
    stored.issuer_key = cached_cert.issuer_key.clone();
    Some(stored)
}
/// Refreshes OCSP staples for cached certificates whose staple needs updating.
///
/// Does nothing when `config.disable_stapling` is set. Each certificate
/// selected by [`needs_ocsp_refresh`] is cloned and passed to
/// `ocsp::staple_ocsp`, and the result is handled as follows:
/// - `Ok(true)`: the refreshed certificate replaces the cached entry.
/// - `Ok(false)` with no stapled response: stapling is unavailable; skip.
/// - `Ok(false)` with a stapled response: the certificate is treated as revoked.
///   If `config.replace_revoked` is set and the certificate is managed and named,
///   it is force-renewed by awaiting `renew_func` inline; otherwise a warning is
///   logged.
/// - `Err(_)`: logged at warn level; the pass continues.
async fn update_ocsp_staples(
    cache: &CertCache,
    storage: &dyn Storage,
    config: &OcspConfig,
    renew_func: &RenewFn,
) {
    if config.disable_stapling {
        debug!("maintenance: OCSP stapling is disabled; skipping update");
        return;
    }
    let all_certs = cache.get_all().await;
    if all_certs.is_empty() {
        debug!("maintenance: no certificates to check for OCSP updates");
        return;
    }
    debug!(
        count = all_certs.len(),
        "maintenance: checking certificates for OCSP staple freshness",
    );
    for cert in all_certs {
        if !needs_ocsp_refresh(&cert) {
            continue;
        }
        let first_name = cert.names.first().cloned().unwrap_or_default();
        let old_hash = cert.hash.clone();
        debug!(
            name = %first_name,
            hash = %old_hash,
            "maintenance: refreshing OCSP staple",
        );
        let mut updated_cert = cert.clone();
        let not_revoked = match ocsp::staple_ocsp(storage, &mut updated_cert, config).await {
            Ok(not_revoked) => not_revoked,
            Err(e) => {
                warn!(
                    name = %first_name,
                    error = %e,
                    "maintenance: failed to refresh OCSP staple",
                );
                continue;
            }
        };
        if not_revoked {
            debug!(
                name = %first_name,
                "maintenance: OCSP staple refreshed successfully",
            );
            cache.replace(&old_hash, updated_cert).await;
            continue;
        }
        if updated_cert.ocsp_response.is_none() {
            debug!(
                name = %first_name,
                "maintenance: OCSP stapling not available for this certificate",
            );
            continue;
        }
        // A response was stapled but did not report "good": treat as revoked.
        // A revoked certificate that will stay in service needs operator
        // attention, so surface it at warn level rather than hiding it.
        if !(config.replace_revoked && cert.managed) {
            warn!(
                name = %first_name,
                hash = %old_hash,
                "maintenance: OCSP status is Revoked; automatic replacement is not enabled for this certificate",
            );
            continue;
        }
        // Renewal is keyed by domain name; never ask to renew the empty string
        // (check_renewals likewise skips nameless certificates).
        if first_name.is_empty() {
            warn!(
                hash = %old_hash,
                "maintenance: OCSP status is Revoked but certificate has no names; cannot force-renew",
            );
            continue;
        }
        warn!(
            name = %first_name,
            hash = %old_hash,
            "maintenance: OCSP status is Revoked; triggering force-renewal",
        );
        match renew_func(first_name.clone()).await {
            Ok(()) => {
                info!(
                    name = %first_name,
                    "maintenance: revoked certificate force-renewed successfully",
                );
            }
            Err(e) => {
                error!(
                    name = %first_name,
                    error = %e,
                    "maintenance: revoked certificate force-renewal failed",
                );
            }
        }
    }
}
/// Reports whether `cert`'s OCSP staple should be fetched again.
///
/// A certificate with no stapled response always needs one. A stapled response
/// is kept only if it can be decoded and `ocsp::ocsp_needs_update` reports it
/// as still fresh; undecodable responses are refreshed.
fn needs_ocsp_refresh(cert: &crate::certificates::Certificate) -> bool {
    let raw = match cert.ocsp_response.as_deref() {
        Some(raw) => raw,
        None => return true,
    };
    try_parse_ocsp_for_freshness(raw)
        .map(|parsed| ocsp::ocsp_needs_update(&parsed))
        .unwrap_or(true)
}
/// Attempts to decode a stapled OCSP response so its freshness can be judged.
///
/// NOTE(review): this is currently a stub that never decodes anything. As a
/// result, `needs_ocsp_refresh` treats every stapled response as stale, and
/// every OCSP pass re-staples every cached certificate.
/// TODO: wire up a real DER parser.
fn try_parse_ocsp_for_freshness(_raw: &[u8]) -> Option<ocsp::OcspResponse> {
    None
}
/// Retention settings for [`clean_storage`].
#[derive(Debug, Clone)]
pub struct CleanStorageOptions {
    /// How long past a certificate's expiry (`notAfter`) its files are kept.
    pub expired_cert_grace_period: Duration,
    /// Maximum age, by last-modified time, of a stored OCSP staple.
    pub ocsp_max_age: Duration,
}
impl Default for CleanStorageOptions {
    /// One day of grace for expired certificates; two weeks for OCSP staples.
    fn default() -> Self {
        const DAY_SECS: u64 = 24 * 60 * 60;
        Self {
            expired_cert_grace_period: Duration::from_secs(DAY_SECS),
            ocsp_max_age: Duration::from_secs(14 * DAY_SECS),
        }
    }
}
/// Counts reported by [`clean_storage`].
#[derive(Debug, Clone, Default)]
pub struct CleanStorageResult {
    /// Number of expired certificates handled by the certificate-cleanup pass.
    pub deleted_certs: usize,
    /// Number of stale OCSP staples successfully deleted.
    pub deleted_ocsp: usize,
}
pub async fn clean_storage(
storage: &dyn Storage,
options: &CleanStorageOptions,
) -> Result<CleanStorageResult> {
let mut result = CleanStorageResult::default();
let certs_entries = match storage.list("certificates", true).await {
Ok(entries) => entries,
Err(e) => {
warn!(error = %e, "clean_storage: failed to list certificates");
Vec::new()
}
};
let now = chrono::Utc::now();
let grace_period = chrono::Duration::from_std(options.expired_cert_grace_period)
.unwrap_or_else(|_| chrono::Duration::hours(24));
for entry in &certs_entries {
if !entry.ends_with(".crt") {
continue;
}
let cert_pem = match storage.load(entry).await {
Ok(data) => data,
Err(_) => continue,
};
let cert_pem_str = match std::str::from_utf8(&cert_pem) {
Ok(s) => s,
Err(_) => continue,
};
let pems = match pem::parse_many(cert_pem_str) {
Ok(p) => p,
Err(_) => continue,
};
let leaf_der = match pems.iter().find(|p| p.tag() == "CERTIFICATE") {
Some(p) => p.contents(),
None => continue,
};
let not_after = match x509_parser::parse_x509_certificate(leaf_der) {
Ok((_, cert)) => {
let epoch = cert.validity().not_after.timestamp();
match chrono::DateTime::<chrono::Utc>::from_timestamp(epoch, 0) {
Some(dt) => dt,
None => continue,
}
}
Err(_) => continue,
};
if now > not_after + grace_period {
let base = entry.trim_end_matches(".crt");
let key_path = format!("{base}.key");
let meta_path = format!("{base}.json");
for path in [entry.as_str(), key_path.as_str(), meta_path.as_str()] {
if let Err(e) = storage.delete(path).await {
warn!(
path = path,
error = %e,
"clean_storage: failed to delete expired certificate asset"
);
}
}
info!(
cert_path = entry.as_str(),
not_after = %not_after,
"clean_storage: deleted expired certificate",
);
result.deleted_certs += 1;
}
}
let ocsp_entries = match storage.list("ocsp", true).await {
Ok(entries) => entries,
Err(e) => {
warn!(error = %e, "clean_storage: failed to list OCSP staples");
Vec::new()
}
};
let max_age = chrono::Duration::from_std(options.ocsp_max_age)
.unwrap_or_else(|_| chrono::Duration::days(14));
for entry in &ocsp_entries {
let stat = match storage.stat(entry).await {
Ok(s) => s,
Err(_) => continue,
};
if now > stat.modified + max_age {
if let Err(e) = storage.delete(entry).await {
warn!(
path = entry.as_str(),
error = %e,
"clean_storage: failed to delete stale OCSP staple"
);
} else {
debug!(
path = entry.as_str(),
modified = %stat.modified,
"clean_storage: deleted stale OCSP staple",
);
result.deleted_ocsp += 1;
}
}
}
Ok(result)
}
/// Returns `interval`, or `default` when `interval` is zero (i.e. unset).
fn normalize_interval(interval: Duration, default: Duration) -> Duration {
    Some(interval)
        .filter(|configured| !configured.is_zero())
        .unwrap_or(default)
}
// Unit tests for the maintenance module: interval defaults, OCSP freshness,
// renewal scheduling, and loop shutdown.
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use chrono::{Duration as ChronoDuration, Utc};
    use super::*;
    use crate::cache::CacheOptions;
    use crate::certificates::{Certificate, PrivateKeyKind};
    // Builds a managed certificate with the given names and hash that expires
    // `days_remaining` days from now. not_before is set so the total validity
    // is 90 days. No DER chain, private key, OCSP data or ARI info is attached.
    fn make_managed_cert(names: &[&str], hash: &str, days_remaining: i64) -> Certificate {
        let now = Utc::now();
        Certificate {
            cert_chain: Vec::new(),
            private_key_der: None,
            private_key_kind: PrivateKeyKind::None,
            names: names.iter().map(|n| n.to_string()).collect(),
            tags: Vec::new(),
            managed: true,
            issuer_key: String::new(),
            hash: hash.to_string(),
            ocsp_response: None,
            ocsp_status: None,
            not_after: now + ChronoDuration::days(days_remaining),
            not_before: now - ChronoDuration::days(90 - days_remaining),
            ari: None,
        }
    }
    // A zero interval falls back to the supplied default.
    #[test]
    fn test_normalize_interval_zero() {
        let result = normalize_interval(Duration::ZERO, DEFAULT_RENEW_CHECK_INTERVAL);
        assert_eq!(result, DEFAULT_RENEW_CHECK_INTERVAL);
    }
    // A non-zero interval is returned unchanged.
    #[test]
    fn test_normalize_interval_custom() {
        let custom = Duration::from_secs(42);
        let result = normalize_interval(custom, DEFAULT_RENEW_CHECK_INTERVAL);
        assert_eq!(result, custom);
    }
    // With no stapled response, a refresh is always needed.
    #[test]
    fn test_needs_ocsp_refresh_no_response() {
        let cert = make_managed_cert(&["example.com"], "h1", 60);
        assert!(needs_ocsp_refresh(&cert));
    }
    // Bytes that cannot be decoded as an OCSP response are treated as stale.
    #[test]
    fn test_needs_ocsp_refresh_with_response() {
        let mut cert = make_managed_cert(&["example.com"], "h1", 60);
        cert.ocsp_response = Some(vec![1, 2, 3]);
        assert!(needs_ocsp_refresh(&cert));
    }
    // An empty cache must not invoke the renewal callback.
    #[tokio::test]
    async fn test_check_renewals_no_certs() {
        let cache = CertCache::new(CacheOptions::default());
        let call_count = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&call_count);
        // Counts how many renewal futures are actually run.
        let renew_func: Box<RenewFn> = Box::new(move |_domain: String| {
            let counter = Arc::clone(&counter);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        });
        let job_queue = JobQueue::new("test");
        check_renewals(
            &cache,
            DEFAULT_RENEWAL_WINDOW_RATIO,
            renew_func.as_ref(),
            &DummyStorage,
            &job_queue,
        )
        .await;
        // Presumably gives the job queue time to run any submitted job — TODO
        // confirm JobQueue executes asynchronously.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 0);
    }
    // A certificate with 60 of 90 days left is expected to be outside the
    // default renewal window, so no renewal runs.
    #[tokio::test]
    async fn test_check_renewals_fresh_cert() {
        let cache = CertCache::new(CacheOptions::default());
        let cert = make_managed_cert(&["example.com"], "h1", 60);
        cache.add(cert).await;
        let call_count = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&call_count);
        let renew_func: Box<RenewFn> = Box::new(move |_domain: String| {
            let counter = Arc::clone(&counter);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        });
        let job_queue = JobQueue::new("test");
        check_renewals(
            &cache,
            DEFAULT_RENEWAL_WINDOW_RATIO,
            renew_func.as_ref(),
            &DummyStorage,
            &job_queue,
        )
        .await;
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 0);
    }
    // A certificate with 10 of 90 days left is expected to be due. DummyStorage
    // has nothing stored, so the reload-from-storage shortcut is not taken and
    // exactly one renewal runs.
    #[tokio::test]
    async fn test_check_renewals_due_cert() {
        let cache = CertCache::new(CacheOptions::default());
        let cert = make_managed_cert(&["renew-me.example.com"], "h2", 10);
        cache.add(cert).await;
        let call_count = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&call_count);
        let renew_func: Box<RenewFn> = Box::new(move |_domain: String| {
            let counter = Arc::clone(&counter);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        });
        let job_queue = JobQueue::new("test");
        check_renewals(
            &cache,
            DEFAULT_RENEWAL_WINDOW_RATIO,
            renew_func.as_ref(),
            &DummyStorage,
            &job_queue,
        )
        .await;
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
    // An already-expired certificate (-1 day) is renewed exactly once.
    #[tokio::test]
    async fn test_check_renewals_expired_cert() {
        let cache = CertCache::new(CacheOptions::default());
        let cert = make_managed_cert(&["expired.example.com"], "h3", -1);
        cache.add(cert).await;
        let call_count = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&call_count);
        let renew_func: Box<RenewFn> = Box::new(move |_domain: String| {
            let counter = Arc::clone(&counter);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
        });
        let job_queue = JobQueue::new("test");
        check_renewals(
            &cache,
            DEFAULT_RENEWAL_WINDOW_RATIO,
            renew_func.as_ref(),
            &DummyStorage,
            &job_queue,
        )
        .await;
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
    // Runs the real loop with 50 ms intervals for a couple of ticks, then
    // requests a stop and expects the spawned task to finish within 2 s.
    #[tokio::test]
    async fn test_maintenance_stops_on_signal() {
        use crate::file_storage::FileStorage;
        let cache = CertCache::new(CacheOptions::default());
        let temp_dir = tempfile::tempdir().unwrap();
        let storage: Arc<dyn Storage> = Arc::new(FileStorage::new(temp_dir.path().to_path_buf()));
        let config = MaintenanceConfig {
            renew_check_interval: Duration::from_millis(50),
            ocsp_check_interval: Duration::from_millis(50),
            ocsp: OcspConfig::default(),
            storage,
        };
        let renew_func: Arc<RenewFn> = Arc::new(|_domain: String| Box::pin(async { Ok(()) }));
        let handle = start_maintenance(Arc::clone(&cache), config, renew_func);
        tokio::time::sleep(Duration::from_millis(120)).await;
        cache.stop();
        let result = tokio::time::timeout(Duration::from_secs(2), handle).await;
        assert!(result.is_ok(), "maintenance task did not stop in time");
    }
    // The manual Debug impl names the struct (storage is not printed).
    #[test]
    fn test_maintenance_config_debug() {
        let config = MaintenanceConfig {
            renew_check_interval: DEFAULT_RENEW_CHECK_INTERVAL,
            ocsp_check_interval: DEFAULT_OCSP_CHECK_INTERVAL,
            ocsp: OcspConfig::default(),
            storage: Arc::new(DummyStorage),
        };
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("MaintenanceConfig"));
    }
    // Storage stub: writes, deletes, and locks succeed without effect. `load`
    // and `stat` always report NotFound, `exists` is false, and `list` is empty.
    struct DummyStorage;
    #[async_trait::async_trait]
    impl Storage for DummyStorage {
        async fn store(&self, _key: &str, _value: &[u8]) -> Result<()> {
            Ok(())
        }
        async fn load(&self, _key: &str) -> Result<Vec<u8>> {
            Err(crate::error::StorageError::NotFound("not found".into()).into())
        }
        async fn delete(&self, _key: &str) -> Result<()> {
            Ok(())
        }
        async fn exists(&self, _key: &str) -> Result<bool> {
            Ok(false)
        }
        async fn list(&self, _path: &str, _recursive: bool) -> Result<Vec<String>> {
            Ok(Vec::new())
        }
        async fn stat(&self, _key: &str) -> Result<crate::storage::KeyInfo> {
            Err(crate::error::StorageError::NotFound("not found".into()).into())
        }
        async fn lock(&self, _name: &str) -> Result<()> {
            Ok(())
        }
        async fn unlock(&self, _name: &str) -> Result<()> {
            Ok(())
        }
    }
}