pub mod cloudflare;
use std::{
path::PathBuf,
str::FromStr,
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use anyhow::{Context, Error, anyhow};
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use core::fmt;
use derive_new::new;
use fqdn::FQDN;
use hickory_proto::rr::RecordType;
use ic_bn_lib_common::{
traits::{
Run,
acme::{AcmeCertificateClient, DnsManager, TokenManager},
dns::Resolves,
},
types::acme::{AcmeUrl, Record},
};
use instant_acme::AccountCredentials;
use rustls::{
server::{ClientHello, ResolvesServerCert},
sign::CertifiedKey,
};
use strum_macros::{Display, EnumString};
use tokio::fs;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use x509_parser::prelude::{FromDer, X509Certificate};
use crate::{
RetryError, retry_async,
tls::{
acme::client::{Client, ClientBuilder},
extract_sans, pem_convert_to_rustls_single, sni_matches,
},
};
/// Label of the DNS-01 challenge TXT record (`_acme-challenge.<zone>`).
const ACME_RECORD: &str = "_acme-challenge";
/// File name, inside the configured directory, holding the PEM certificate
/// data immediately followed by the private key.
const FILE_CERT: &str = "cert.pem";
/// TTL of the challenge TXT records; also used to derive the `2 * TTL`-second
/// duration handed to `retry_async!` while verifying a token.
const TTL: u32 = 60;
/// DNS-01 [`TokenManager`] that publishes challenge tokens as TXT records through
/// a [`DnsManager`] and checks their visibility with a [`Resolves`] resolver.
///
/// When `delegation_domain` is set, records are managed inside that zone instead
/// of the requested one (the record name then also carries the original zone).
///
/// NOTE: `#[derive(new)]` generates `new(resolver, manager, delegation_domain)`
/// in field order, so the fields must not be reordered.
#[derive(new)]
pub struct TokenManagerDns {
    /// Resolver used to check that a published token has become visible.
    resolver: Arc<dyn Resolves>,
    /// Backend used to create and delete the challenge TXT records.
    manager: Arc<dyn DnsManager>,
    /// Optional zone challenges are delegated to (presumably via a CNAME from
    /// `_acme-challenge.<zone>` — confirm against the deployment setup).
    delegation_domain: Option<String>,
}
impl TokenManagerDns {
    /// Zone in which the challenge record is managed: the delegation domain
    /// when one is configured, otherwise `zone` itself.
    fn pick_zone(&self, zone: &str) -> String {
        match &self.delegation_domain {
            Some(delegated) => delegated.clone(),
            None => zone.to_string(),
        }
    }

    /// Challenge record name relative to the zone returned by `pick_zone`.
    ///
    /// Without delegation this is just `_acme-challenge`; with delegation the
    /// original zone is appended, presumably to keep records for different
    /// zones apart inside the shared delegation zone.
    fn pick_record(&self, zone: &str) -> String {
        if self.delegation_domain.is_some() {
            format!("{ACME_RECORD}.{zone}")
        } else {
            ACME_RECORD.to_string()
        }
    }
}
#[async_trait]
impl TokenManager for TokenManagerDns {
async fn verify(&self, zone: &str, token: &str) -> Result<(), Error> {
let host = format!("{}.{}", self.pick_record(zone), self.pick_zone(zone));
retry_async! {
async {
self.resolver.flush_cache();
let records = self
.resolver
.resolve(RecordType::TXT, &host)
.await
.map_err(|e| RetryError::Transient(e.into()))?;
records
.iter()
.find(|&x| x.record_type() == RecordType::TXT && x.data().to_string() == token)
.ok_or_else(|| RetryError::Transient(anyhow!("requested record not found")))?;
Ok(())
}, Duration::from_secs(2 * TTL as u64)}
}
async fn set(&self, zone: &str, token: &str) -> Result<(), Error> {
self.manager
.create(
&self.pick_zone(zone),
&self.pick_record(zone),
Record::Txt(token.into()),
TTL,
)
.await
}
async fn unset(&self, zone: &str) -> Result<(), Error> {
self.manager
.delete(&self.pick_zone(zone), &self.pick_record(zone))
.await
}
}
/// Result of checking the certificate currently held in memory.
#[derive(Clone, Debug, Eq, PartialEq, Display, EnumString)]
pub enum Validity {
    /// No certificate is loaded, or the loaded chain is empty.
    Missing,
    /// Less than the configured renewal window is left before `notAfter`.
    Expires,
    /// The certificate's SANs differ from the configured names.
    SANMismatch,
    /// Present, not close to expiry, and covering exactly the configured names.
    Valid,
}
/// Outcome of a successful refresh attempt.
#[derive(Clone, Debug, Eq, PartialEq, Display, EnumString)]
pub enum RefreshResult {
    /// The existing certificate was still valid; nothing was issued.
    StillValid,
    /// A new certificate was issued, written to disk and loaded.
    Refreshed,
}
/// Obtains a TLS certificate for a fixed set of domains via ACME DNS-01 and
/// serves it to rustls.
///
/// The current certificate is kept in memory and persisted under `path`.
pub struct AcmeDns {
    /// Domains this instance serves (used for SNI matching).
    domains: Vec<FQDN>,
    /// Whether `*.<domain>` is requested and served for every domain.
    wildcard: bool,
    /// Sorted certificate names: every domain, plus its wildcard if enabled.
    names: Vec<String>,
    /// Renew once less than this much validity remains.
    renew_before: Duration,
    /// Directory holding the account credentials and the certificate file.
    path: PathBuf,
    /// ACME client used to issue certificates.
    client: Arc<Client>,
    /// Currently loaded certificate chain and key; empty until loaded/issued.
    cert: ArcSwapOption<CertifiedKey>,
}
/// Settings for `AcmeDns::new`.
pub struct Opts {
    /// ACME directory endpoint.
    pub acme_url: AcmeUrl,
    /// Passed to the ACME client builder; presumably relaxes TLS verification
    /// towards the ACME server (e.g. for test servers) — confirm.
    pub insecure_tls: bool,
    /// Existing account to use. When `None`, the account is read from
    /// `account.json` in `path`, or created and saved there.
    pub account_credentials: Option<AccountCredentials>,
    /// Publishes and verifies DNS-01 challenge tokens.
    pub token_manager: Arc<dyn TokenManager>,
    /// Directory for the account credentials and the certificate file.
    pub path: PathBuf,
    /// Domains to obtain a certificate for.
    pub domains: Vec<FQDN>,
    /// Also request `*.<domain>` for every domain.
    pub wildcard: bool,
    /// Renew once less than this much validity remains.
    pub renew_before: Duration,
}
impl AcmeDns {
pub async fn new(opts: Opts) -> Result<Self, Error> {
let mut builder = ClientBuilder::new(opts.insecure_tls)
.with_acme_url(opts.acme_url)
.with_token_manager(opts.token_manager);
let account_path = opts.path.join("account.json");
let mut names = opts
.domains
.clone()
.into_iter()
.flat_map(|x| {
let x = x.to_string();
let mut out = vec![x.clone()];
if opts.wildcard {
out.push(format!("*.{x}"));
}
out.into_iter()
})
.collect::<Vec<_>>();
names.sort();
if let Some(v) = opts.account_credentials {
builder = builder
.load_account(v)
.await
.context("unable to load ACME account")?;
} else if let Ok(v) = fs::read(&account_path).await {
let creds: AccountCredentials =
serde_json::from_slice(&v).context("unable to parse ACME account credentials")?;
builder = builder
.load_account(creds)
.await
.context("unable to load ACME account")?;
} else {
let (builder2, creds) = builder
.create_account("mailto:boundary-nodes@dfinity.org")
.await
.context("unable to create ACME account")?;
builder = builder2;
let creds = serde_json::to_vec_pretty(&creds)
.context("unable to serialize ACME credentials to JSON")?;
fs::write(&account_path, creds)
.await
.context("unable to save ACME credentials to file")?;
}
let client = Arc::new(
builder
.build()
.await
.context("unable to create ACME client")?,
);
Ok(Self {
client,
path: opts.path,
domains: opts.domains,
names,
wildcard: opts.wildcard,
renew_before: opts.renew_before,
cert: ArcSwapOption::empty(),
})
}
async fn load(&self) -> Result<(), Error> {
let cert_and_key = fs::read(self.path.join(FILE_CERT))
.await
.context("unable to read cert")?;
let ckey = pem_convert_to_rustls_single(&cert_and_key)
.context("unable to convert certificate to Rustls format")?;
self.cert.store(Some(Arc::new(ckey)));
Ok(())
}
pub async fn is_valid(&self) -> Result<Validity, Error> {
let Some(ckey) = self.cert.load_full() else {
return Ok(Validity::Missing);
};
if ckey.cert.is_empty() {
return Ok(Validity::Missing);
}
let cert = X509Certificate::from_der(ckey.cert[0].as_ref())
.context("Unable to parse DER-encoded certificate")?
.1;
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
let left = (cert.validity().not_after.timestamp() as u64).saturating_sub(now);
if left < self.renew_before.as_secs() {
return Ok(Validity::Expires);
}
let mut sans = extract_sans(&cert)?;
sans.sort();
if sans != self.names {
return Ok(Validity::SANMismatch);
}
Ok(Validity::Valid)
}
async fn refresh(&self) -> Result<RefreshResult, Error> {
if self.cert.load_full().is_none() {
let _ = self.load().await;
}
let validity = self.is_valid().await.context("unable to check validity")?;
if validity == Validity::Valid {
debug!("ACME-DNS: Certificate is still valid");
return Ok(RefreshResult::StillValid);
}
debug!("ACME-DNS: Certificate validity is '{validity}', renewing");
let cert = self
.client
.issue(self.names.clone(), None)
.await
.context("unable to issue a certificate")?;
let cert_and_key = [cert.cert, cert.key].concat();
fs::write(self.path.join(FILE_CERT), &cert_and_key)
.await
.context("unable to store certificate")?;
self.load()
.await
.context("unable to load certificate from disk")?;
Ok(RefreshResult::Refreshed)
}
}
/// Minimal `Debug` output (rustls's `ResolvesServerCert` requires `Debug`).
///
/// Only the type name is printed; no field — in particular not the held
/// certificate/key — is included.
impl fmt::Debug for AcmeDns {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AcmeDns")
    }
}
impl ResolvesServerCert for AcmeDns {
    /// Returns the current certificate if the client's SNI matches one of the
    /// configured domains (or their wildcards, when enabled).
    ///
    /// Handshakes without SNI, with an unparsable name, or with a non-matching
    /// name get `None`, as do matching ones before any certificate is loaded.
    fn resolve(&self, ch: ClientHello) -> Option<Arc<CertifiedKey>> {
        let server_name = ch.server_name()?;
        let sni = FQDN::from_str(server_name).ok()?;

        if !sni_matches(&sni, &self.domains, self.wildcard) {
            return None;
        }

        self.cert.load_full()
    }
}
#[async_trait]
impl Run for AcmeDns {
async fn run(&self, _: CancellationToken) -> Result<(), Error> {
self.refresh()
.await
.context("unable to refresh")
.map(|_| ())
}
}
#[cfg(test)]
mod test {
    use fqdn::fqdn;
    use tempfile::tempdir;

    use super::*;
    use crate::{
        tests::pebble::{Env, dns::TokenManagerPebble},
        tls::extract_sans_der,
    };

    /// Asserts that the certificate currently held by `acme_dns` covers
    /// exactly `foo` and `*.foo`.
    fn assert_foo_sans(acme_dns: &AcmeDns) {
        let ckey = acme_dns.cert.load_full().unwrap();
        let mut sans = extract_sans_der(ckey.end_entity_cert().unwrap()).unwrap();
        sans.sort();
        assert_eq!(sans, vec!["*.foo", "foo"]);
    }

    /// End-to-end issuance: the first refresh issues a certificate, the
    /// second one finds it still valid.
    // NOTE(review): ignored by default, presumably because it needs the
    // Pebble test environment to be available.
    #[ignore]
    #[tokio::test]
    async fn test_acme_dns() {
        let _ = rustls::crypto::ring::default_provider().install_default();

        let pebble_env = Env::new().await;
        // Must stay alive for the whole test: dropping it deletes the directory.
        let dir = tempdir().unwrap();

        let token_manager = Arc::new(TokenManagerPebble::new(
            format!("http://{}", pebble_env.addr_dns_management())
                .parse()
                .unwrap(),
        ));
        let resolver = pebble_env.resolver();
        let token_manager_dns = Arc::new(TokenManagerDns::new(resolver, token_manager, None));

        let opts = Opts {
            acme_url: AcmeUrl::Custom(
                format!("https://{}/dir", pebble_env.addr_acme())
                    .parse()
                    .unwrap(),
            ),
            path: dir.path().to_path_buf(),
            domains: vec![fqdn!("foo")],
            wildcard: true,
            renew_before: Duration::from_secs(30),
            account_credentials: None,
            token_manager: token_manager_dns,
            insecure_tls: true,
        };
        let acme_dns = AcmeDns::new(opts).await.unwrap();

        // Nothing on disk yet, so a certificate has to be issued.
        assert_eq!(acme_dns.refresh().await.unwrap(), RefreshResult::Refreshed);
        assert_foo_sans(&acme_dns);

        // The freshly issued certificate must be accepted as-is.
        assert_eq!(acme_dns.refresh().await.unwrap(), RefreshResult::StillValid);
        assert_foo_sans(&acme_dns);
    }
}