use rcgen::string::Ia5String;
use rcgen::{
CertificateParams, CustomExtension, DistinguishedName, KeyPair, SanType,
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use rustls::sign::CertifiedKey;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Map from (lowercased) SNI host name to its pending challenge certificate.
type ChallengeMap = HashMap<String, Arc<CertifiedKey>>;

/// Shared registry of in-flight TLS-ALPN-01 challenge certificates, keyed by
/// the SNI host name each certificate answers for.
///
/// DNS names compare case-insensitively, so every key is folded to ASCII
/// lowercase on insert, lookup and removal. This makes lookups independent of
/// how the caller spelled the identifier and of the case of the SNI that
/// reaches the resolver.
///
/// Cloning is cheap: clones share the same underlying map.
#[derive(Clone, Default)]
pub struct AlpnChallengeStore {
    inner: Arc<RwLock<ChallengeMap>>,
}

impl AlpnChallengeStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `ck` as the challenge certificate for `sni`, replacing any
    /// previous entry for the same (case-insensitive) name.
    pub fn put(&self, sni: String, ck: Arc<CertifiedKey>) {
        let mut key = sni;
        // In-place fold: reuses the caller's allocation.
        key.make_ascii_lowercase();
        self.map_mut().insert(key, ck);
    }

    /// Removes the challenge certificate for `sni`, if one is registered.
    pub fn remove(&self, sni: &str) {
        self.map_mut().remove(&*Self::normalize(sni));
    }

    /// Returns the challenge certificate registered for `sni`, if any.
    pub fn get(&self, sni: &str) -> Option<Arc<CertifiedKey>> {
        self.map().get(&*Self::normalize(sni)).cloned()
    }

    /// Returns `true` when no challenge certificates are registered.
    pub fn is_empty(&self) -> bool {
        self.map().is_empty()
    }

    /// Folds `sni` to ASCII lowercase. When the name is already lowercase it is
    /// borrowed, so the per-handshake lookup does not allocate.
    fn normalize(sni: &str) -> std::borrow::Cow<'_, str> {
        if sni.bytes().any(|b| b.is_ascii_uppercase()) {
            std::borrow::Cow::Owned(sni.to_ascii_lowercase())
        } else {
            std::borrow::Cow::Borrowed(sni)
        }
    }

    /// Acquires the read lock. A poisoned lock is recovered rather than
    /// propagated: a single insert/remove cannot leave the map half-updated.
    fn map(&self) -> std::sync::RwLockReadGuard<'_, ChallengeMap> {
        self.inner.read().unwrap_or_else(|p| p.into_inner())
    }

    /// Acquires the write lock, recovering from poisoning like [`Self::map`].
    fn map_mut(&self) -> std::sync::RwLockWriteGuard<'_, ChallengeMap> {
        self.inner.write().unwrap_or_else(|p| p.into_inner())
    }
}
pub fn build_challenge_cert(
domain: &str,
key_auth_digest: &[u8],
) -> anyhow::Result<Arc<CertifiedKey>> {
const ACME_OID: &[u64] = &[1, 3, 6, 1, 5, 5, 7, 1, 31];
if key_auth_digest.len() != 32 {
anyhow::bail!(
"ACME key authorization digest must be 32 bytes; got {}",
key_auth_digest.len()
);
}
let mut der_value = Vec::with_capacity(2 + 32);
der_value.push(0x04); der_value.push(0x20); der_value.extend_from_slice(key_auth_digest);
let mut ext = CustomExtension::from_oid_content(
ACME_OID,
der_value,
);
ext.set_criticality(true);
let mut params = CertificateParams::new(Vec::<String>::new())?;
params.distinguished_name = DistinguishedName::new();
params.subject_alt_names = vec![SanType::DnsName(
Ia5String::try_from(domain.to_string())
.map_err(|e| anyhow::anyhow!("invalid SAN '{domain}': {e}"))?,
)];
params.custom_extensions.push(ext);
let kp = KeyPair::generate()?;
let cert = params.self_signed(&kp)?;
let cert_der = CertificateDer::from(cert.der().to_vec());
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(
kp.serialize_der(),
));
let signing_key = rustls::crypto::aws_lc_rs::sign::any_supported_type(
&key_der,
)
.map_err(|e| anyhow::anyhow!("loading ACME challenge key: {e}"))?;
Ok(Arc::new(CertifiedKey::new(vec![cert_der], signing_key)))
}
/// Server certificate resolver for TLS-ALPN-01.
///
/// Handshakes that offer the `acme-tls/1` ALPN protocol get the challenge
/// certificate registered for their SNI name. Every other handshake gets the
/// production certificate.
pub struct AlpnAwareResolver {
    /// Pending challenge certificates, looked up by the client's SNI name.
    pub store: AlpnChallengeStore,
    /// Certificate served whenever no challenge certificate applies.
    pub production: Arc<CertifiedKey>,
}

impl std::fmt::Debug for AlpnAwareResolver {
    // Reports only whether any challenge is pending; the certificates
    // themselves are not printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let store_empty = self.store.is_empty();
        formatter
            .debug_struct("AlpnAwareResolver")
            .field("store_empty", &store_empty)
            .finish()
    }
}

impl rustls::server::ResolvesServerCert for AlpnAwareResolver {
    /// Chooses the certificate for one handshake.
    ///
    /// The challenge certificate is used only when all three hold:
    /// - the client offers `acme-tls/1`;
    /// - it sent an SNI name;
    /// - a certificate is registered for that name.
    ///
    /// Otherwise the production certificate is returned, so this never yields
    /// `None`.
    fn resolve(
        &self,
        client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<CertifiedKey>> {
        const ACME_TLS_ALPN: &[u8] = b"acme-tls/1";

        let offers_acme = client_hello
            .alpn()
            .is_some_and(|mut offered| offered.any(|proto| proto == ACME_TLS_ALPN));

        let challenge = if offers_acme {
            client_hello
                .server_name()
                .and_then(|name| self.store.get(name))
        } else {
            None
        };

        Some(challenge.unwrap_or_else(|| Arc::clone(&self.production)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// DER content octets of the `id-pe-acmeIdentifier` OID (1.3.6.1.5.5.7.1.31).
    const ACME_OID_DER: [u8; 8] = [0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x1f];

    #[test]
    fn store_round_trip() {
        let store = AlpnChallengeStore::new();
        assert!(store.is_empty());

        let key = build_challenge_cert("foo.example", &[0u8; 32]).unwrap();
        store.put(String::from("foo.example"), key);
        assert!(!store.is_empty());

        let fetched = store.get("foo.example").unwrap();
        assert_eq!(fetched.cert.len(), 1);

        store.remove("foo.example");
        assert!(store.is_empty());
    }

    #[test]
    fn challenge_cert_contains_acme_extension() {
        use x509_parser::prelude::FromDer;

        let digest = [0x42u8; 32];
        let ck = build_challenge_cert("foo.example", &digest).unwrap();
        let der = ck.cert[0].as_ref();
        let (_, parsed) =
            x509_parser::certificate::X509Certificate::from_der(der).expect("parse cert");

        let ext = parsed
            .extensions()
            .iter()
            .find(|e| e.oid.as_bytes() == ACME_OID_DER)
            .expect("acmeIdentifier extension");
        assert!(ext.critical, "acme extension must be critical");

        // Expected layout: OCTET STRING tag (0x04), length 32 (0x20), then the digest.
        let (header, body) = ext.value.split_at(2);
        assert_eq!(header, [0x04u8, 0x20u8]);
        assert_eq!(body, digest);
    }
}