use std::sync::Arc;
use std::time::Duration;
use tokio::time::{interval, Instant};
use tracing::{debug, error, info, warn};
use zentinel_config::server::AcmeChallengeType;
use super::challenge::ChallengeManager;
use super::client::AcmeClient;
use super::dns::Dns01ChallengeManager;
use super::error::AcmeError;
use crate::tls::HotReloadableSniResolver;
/// Default time between scheduled certificate renewal checks (12 hours).
const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_secs(12 * 60 * 60);
/// Shortest interval `RenewalScheduler::with_interval` will accept (1 hour);
/// shorter requests are clamped up to this value.
const MIN_CHECK_INTERVAL: Duration = Duration::from_secs(60 * 60);
/// Periodically checks whether the ACME-managed certificate needs renewal and
/// renews it using the configured challenge type (HTTP-01 or DNS-01).
pub struct RenewalScheduler {
    /// ACME client used to check renewal status, place orders and save certificates.
    client: Arc<AcmeClient>,
    /// Time between scheduled renewal checks.
    check_interval: Duration,
    /// Receives HTTP-01 challenge tokens (`add_challenge` / `remove_challenge`)
    /// while an order is being validated.
    challenge_manager: Arc<ChallengeManager>,
    /// DNS-01 record manager; `None` unless set via `with_dns_manager`.
    dns_challenge_manager: Option<Arc<Dns01ChallengeManager>>,
    /// TLS resolver reloaded after a successful scheduled renewal, when present.
    sni_resolver: Option<Arc<HotReloadableSniResolver>>,
}
impl RenewalScheduler {
    /// Creates a scheduler that checks for renewal every `DEFAULT_CHECK_INTERVAL`.
    ///
    /// No DNS-01 manager is attached; call [`Self::with_dns_manager`] when the client
    /// is configured for DNS-01. When `sni_resolver` is `Some`, it is reloaded after
    /// each successful scheduled renewal.
    pub fn new(
        client: Arc<AcmeClient>,
        challenge_manager: Arc<ChallengeManager>,
        sni_resolver: Option<Arc<HotReloadableSniResolver>>,
    ) -> Self {
        Self {
            client,
            challenge_manager,
            dns_challenge_manager: None,
            sni_resolver,
            check_interval: DEFAULT_CHECK_INTERVAL,
        }
    }

    /// Attaches the DNS-01 challenge manager.
    ///
    /// Required when the configured challenge type is DNS-01; without it, DNS-01
    /// renewals fail with `AcmeError::NoDnsProvider`.
    pub fn with_dns_manager(mut self, dns_manager: Arc<Dns01ChallengeManager>) -> Self {
        self.dns_challenge_manager = Some(dns_manager);
        self
    }

    /// Overrides the time between renewal checks.
    ///
    /// Values shorter than `MIN_CHECK_INTERVAL` are clamped up to it.
    pub fn with_interval(mut self, interval: Duration) -> Self {
        self.check_interval = interval.max(MIN_CHECK_INTERVAL);
        self
    }

    /// Challenge type taken from the ACME client's configuration.
    fn challenge_type(&self) -> AcmeChallengeType {
        self.client.config().challenge_type
    }

    /// Runs the renewal loop forever.
    ///
    /// After a short startup delay, one check runs immediately. Further checks run once
    /// per `check_interval`. Errors from individual checks are logged and never stop
    /// the loop.
    pub async fn run(self) {
        // Delay before the first check (presumably a startup grace period so the rest
        // of the server can come up first — confirm).
        const INITIAL_CHECK_DELAY: Duration = Duration::from_secs(10);

        info!(
            check_interval_hours = self.check_interval.as_secs() / 3600,
            "Starting certificate renewal scheduler"
        );

        tokio::time::sleep(INITIAL_CHECK_DELAY).await;
        if let Err(e) = self.check_renewals().await {
            error!(error = %e, "Initial certificate renewal check failed");
        }

        let mut ticker = interval(self.check_interval);
        // The default `Burst` behaviour would fire several back-to-back checks after the
        // process was suspended or a check overran; one check per wake-up is enough.
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        // A tokio `Interval` completes its first tick immediately. Consume it so the next
        // check happens one full interval after the initial check, not right away (an
        // immediate retry after a failed renewal could also hit ACME rate limits).
        ticker.tick().await;

        loop {
            ticker.tick().await;
            debug!("Running scheduled certificate renewal check");
            if let Err(e) = self.check_renewals().await {
                error!(error = %e, "Certificate renewal check failed");
            }
        }
    }

    /// Renews the certificate if the primary domain reports it is due.
    ///
    /// Only the first configured domain is checked, because the certificate is saved
    /// under that domain (see `save_certificate`). After a successful renewal the SNI
    /// resolver, if any, is reloaded; a reload failure is logged, not returned.
    ///
    /// # Errors
    ///
    /// Returns the renewal error if renewing fails. A failure to *determine* renewal
    /// status is only logged as a warning and yields `Ok(())`.
    async fn check_renewals(&self) -> Result<(), AcmeError> {
        let domains = self.client.config().domains.clone();
        let domain = match domains.first() {
            Some(domain) => domain,
            None => return Ok(()),
        };

        match self.client.needs_renewal(domain) {
            Ok(true) => {
                info!(domain = %domain, "Certificate needs renewal");
                if let Err(e) = self.renew_certificate().await {
                    error!(
                        domain = %domain,
                        error = %e,
                        "Certificate renewal failed"
                    );
                    return Err(e);
                }

                info!(domain = %domain, "Certificate renewed successfully");
                if let Some(ref resolver) = self.sni_resolver {
                    if let Err(e) = resolver.reload() {
                        error!(
                            domain = %domain,
                            error = %e,
                            "Failed to reload TLS configuration"
                        );
                    } else {
                        info!("TLS configuration reloaded with new certificate");
                    }
                }
            }
            Ok(false) => {
                debug!(domain = %domain, "Certificate is still valid");
            }
            Err(e) => {
                warn!(
                    domain = %domain,
                    error = %e,
                    "Failed to check certificate renewal status"
                );
            }
        }
        Ok(())
    }

    /// Obtains a new certificate using the configured challenge type.
    async fn renew_certificate(&self) -> Result<(), AcmeError> {
        match self.challenge_type() {
            AcmeChallengeType::Http01 => self.renew_certificate_http01().await,
            AcmeChallengeType::Dns01 => self.renew_certificate_dns01().await,
        }
    }

    /// Obtains and saves a certificate using HTTP-01 challenges.
    ///
    /// Registers every challenge token with the challenge manager, asks the ACME server
    /// to validate each one, waits for the order to become ready, then finalizes it and
    /// saves the result. The registered tokens are removed whether validation succeeds
    /// or fails.
    async fn renew_certificate_http01(&self) -> Result<(), AcmeError> {
        let start = Instant::now();
        info!("Starting certificate renewal with HTTP-01 challenge");

        let (mut order, challenges) = self.client.create_order().await?;

        for challenge in &challenges {
            self.challenge_manager
                .add_challenge(&challenge.token, &challenge.key_authorization);
        }

        // Validate inside a nested future so an early `?` cannot skip the token cleanup
        // below (matches the always-cleanup behaviour of the DNS-01 path).
        let validation_result = async {
            for challenge in &challenges {
                self.client
                    .validate_challenge(&mut order, &challenge.url)
                    .await?;
            }
            self.client.wait_for_order_ready(&mut order).await?;
            Ok::<(), AcmeError>(())
        }
        .await;

        for challenge in &challenges {
            self.challenge_manager.remove_challenge(&challenge.token);
        }
        validation_result?;

        let (cert_pem, key_pem, expires) = self.client.finalize_order(&mut order).await?;
        self.save_certificate(&cert_pem, &key_pem, expires)?;

        let elapsed = start.elapsed();
        info!(
            elapsed_secs = elapsed.as_secs(),
            expires = %expires,
            "Certificate renewal completed (HTTP-01)"
        );
        Ok(())
    }

    /// Obtains and saves a certificate using DNS-01 challenges.
    ///
    /// Creates the DNS records through the DNS-01 manager, asks the ACME server to
    /// validate each challenge, and waits for the order to become ready. All records are
    /// cleaned up on every failure path and after validation, before finalizing.
    ///
    /// # Errors
    ///
    /// Returns `AcmeError::NoDnsProvider` when no DNS-01 manager was attached, or any
    /// error from record creation, validation, finalization or saving.
    async fn renew_certificate_dns01(&self) -> Result<(), AcmeError> {
        let dns_manager = self
            .dns_challenge_manager
            .as_ref()
            .ok_or(AcmeError::NoDnsProvider)?;

        let start = Instant::now();
        info!(
            provider = %dns_manager.provider_name(),
            "Starting certificate renewal with DNS-01 challenge"
        );

        let (mut order, mut challenges) = self.client.create_order_dns01().await?;

        for challenge in &mut challenges {
            if let Err(e) = dns_manager.create_and_wait(challenge).await {
                warn!(
                    domain = %challenge.domain,
                    error = %e,
                    "Failed to create DNS record, cleaning up"
                );
                dns_manager.cleanup_all(&challenges).await;
                return Err(e.into());
            }
        }

        for challenge in &challenges {
            if let Err(e) = self
                .client
                .validate_challenge(&mut order, &challenge.url)
                .await
            {
                dns_manager.cleanup_all(&challenges).await;
                return Err(e);
            }
        }

        // Clean up records before propagating the readiness result so they are removed
        // on both success and failure.
        let validation_result = self.client.wait_for_order_ready(&mut order).await;
        dns_manager.cleanup_all(&challenges).await;
        validation_result?;

        let (cert_pem, key_pem, expires) = self.client.finalize_order(&mut order).await?;
        self.save_certificate(&cert_pem, &key_pem, expires)?;

        let elapsed = start.elapsed();
        info!(
            elapsed_secs = elapsed.as_secs(),
            expires = %expires,
            "Certificate renewal completed (DNS-01)"
        );
        Ok(())
    }

    /// Saves the certificate and key through the client's storage.
    ///
    /// The certificate is keyed by the first configured domain, and the full domain
    /// list is passed along with it.
    ///
    /// # Errors
    ///
    /// Returns `AcmeError::OrderCreation` if no domains are configured, or the storage
    /// error if saving fails.
    fn save_certificate(
        &self,
        cert_pem: &str,
        key_pem: &str,
        expires: chrono::DateTime<chrono::Utc>,
    ) -> Result<(), AcmeError> {
        let primary_domain = self
            .client
            .config()
            .domains
            .first()
            .ok_or_else(|| AcmeError::OrderCreation("No domains configured".to_string()))?;

        self.client.storage().save_certificate(
            primary_domain,
            cert_pem,
            key_pem,
            expires,
            &self.client.config().domains,
        )?;
        Ok(())
    }

    /// Issues a certificate now if the primary domain reports one is needed.
    ///
    /// Unlike the scheduled check, this does not reload the SNI resolver, and errors
    /// from the renewal-status check are propagated instead of logged.
    ///
    /// # Errors
    ///
    /// Returns `AcmeError::OrderCreation` if no domains are configured. Otherwise returns
    /// any error from `needs_renewal` or from issuing the certificate.
    pub async fn ensure_certificates(&self) -> Result<(), AcmeError> {
        let domains = self.client.config().domains.clone();
        let primary_domain = domains
            .first()
            .ok_or_else(|| AcmeError::OrderCreation("No domains configured".to_string()))?;

        if self.client.needs_renewal(primary_domain)? {
            info!(
                domain = %primary_domain,
                "Initial certificate issuance required"
            );
            self.renew_certificate().await?;
        } else {
            info!(
                domain = %primary_domain,
                "Certificate already exists and is valid"
            );
        }
        Ok(())
    }
}
/// Debug view of the scheduler's configuration.
///
/// The client and challenge managers are not printed, only whether the optional
/// components are attached; `finish_non_exhaustive` marks the omitted fields.
impl std::fmt::Debug for RenewalScheduler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RenewalScheduler")
            .field("check_interval", &self.check_interval)
            .field("has_dns_manager", &self.dns_challenge_manager.is_some())
            .field("has_sni_resolver", &self.sni_resolver.is_some())
            .finish_non_exhaustive()
    }
}