use std::{
sync::Arc,
time::{Duration, Instant},
};
use bytes::Bytes;
use openssl::{
asn1::Asn1Time,
pkey::{PKey, Private},
x509::X509,
};
use crate::{
acme::{Account, ChallengeRegistrations},
error::Error,
tls::{self, Identity, TlsAcceptor},
};
const RENEW_TIMEOUT: Duration = Duration::from_secs(60);
pub struct Watchdog {
key: Arc<PKey<Private>>,
key_pem: Vec<u8>,
account: Account,
challenges: ChallengeRegistrations,
tls_acceptor: TlsAcceptor,
hostname: String,
}
impl Watchdog {
pub async fn new<T>(
account: Account,
challenges: ChallengeRegistrations,
tls_acceptor: TlsAcceptor,
hostname: T,
) -> Result<Self, Error>
where
T: ToString,
{
let key: Arc<PKey<Private>> = tokio::task::spawn_blocking(tls::generate_tls_key)
.await
.map_err(|_| Error::from_static_msg("terminating"))??
.into();
let key_pem = key.private_key_to_pem_pkcs8()?;
let res = Self {
key,
key_pem,
account,
challenges,
tls_acceptor,
hostname: hostname.to_string(),
};
Ok(res)
}
pub async fn watch(self) {
loop {
info!("renewing TLS certificate");
let res = self.renew_certificate().await;
if let Err(err) = &res {
warn!("unable to renew TLS certificate: {err}");
} else {
info!("TLS certificate renewed");
}
let delay = res
.map(|expires_in| expires_in * 2 / 3)
.unwrap_or_else(|_| Duration::from_secs(10));
info!("next TLS certificate renew in {:?}", delay);
let start = Instant::now();
loop {
let elapsed = start.elapsed();
if delay <= elapsed {
break;
}
let remaining = delay - elapsed;
tokio::time::sleep(remaining.min(Duration::from_secs(3_600))).await;
}
}
}
async fn renew_certificate(&self) -> Result<Duration, Error> {
let chain = tokio::time::timeout(RENEW_TIMEOUT, self.get_new_certificate())
.await
.map_err(|_| Error::from_static_msg("renew timeout"))??;
let identity = Identity::from_pkcs8(&chain, &self.key_pem)?;
self.tls_acceptor.set_identity(identity)?;
let stack = X509::stack_from_pem(&chain)?;
let cert = stack
.first()
.ok_or_else(|| Error::from_static_msg("no certificate in the chain"))?;
let expires_in = Asn1Time::days_from_now(0)
.and_then(|now| now.diff(cert.not_after()))
.map(|diff| (diff.days as i64) * 86_400 + (diff.secs as i64))?;
if expires_in > 0 {
Ok(Duration::from_secs(expires_in as u64))
} else {
Err(Error::from_static_msg("received expired certificate"))
}
}
async fn get_new_certificate(&self) -> Result<Bytes, Error> {
let order = self
.account
.new_order(&self.hostname)
.await
.map_err(|err| {
Error::from_static_msg_and_cause("unable to create a new certificate order", err)
})?;
let challenge = order.challenge();
let authorized = self.challenges.register(challenge);
let guard = ChallengeRegistrationGuard {
challenges: &self.challenges,
token: challenge.token(),
};
let key = self.key.clone();
let hostname = self.hostname.clone();
let blocking = tokio::task::spawn_blocking(move || tls::create_csr(&key, &hostname));
let csr = blocking
.await
.map_err(|_| Error::from_static_msg("terminating"))??;
let cert = self
.account
.close_order(&order, authorized, &csr)
.await
.map_err(|err| {
Error::from_static_msg_and_cause("unable to close a certificate order", err)
})?;
std::mem::drop(guard);
Ok(cert)
}
}
struct ChallengeRegistrationGuard<'a> {
challenges: &'a ChallengeRegistrations,
token: &'a str,
}
impl<'a> Drop for ChallengeRegistrationGuard<'a> {
fn drop(&mut self) {
self.challenges.deregister(self.token);
}
}