use std::sync::Arc;
use std::time::Duration;
use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey;
use scp_core::store::ProtocolStore;
use scp_platform::traits::Storage;
use tokio::sync::RwLock;
use zeroize::Zeroizing;
/// Renewal window in days: `CertificateData::needs_renewal` reports `true` once the
/// leaf certificate expires in fewer than this many days.
const RENEWAL_THRESHOLD_DAYS: i64 = 30;
/// Pause between two consecutive checks of the background renewal loop (12 hours).
const RENEWAL_CHECK_INTERVAL: Duration = Duration::from_secs(12 * 60 * 60);
/// Errors produced while provisioning, loading, parsing or applying TLS certificates.
///
/// Every `String` payload is the already formatted message of the underlying error.
#[derive(Debug, thiserror::Error)]
pub enum TlsError {
/// A step of the ACME exchange (account, order, challenge, finalize, download) failed.
#[error("ACME error: {0}")]
Acme(String),
/// Certificate or key material could not be generated or parsed (PEM / X.509).
/// `CertificateData::needs_renewal` also uses it when the current time cannot be read.
#[error("certificate error: {0}")]
Certificate(String),
/// Loading or persisting the certificate through the protocol store failed.
#[error("storage error: {0}")]
Storage(String),
/// The rustls server configuration could not be built from the given material.
#[error("TLS config error: {0}")]
Config(String),
/// A required field was absent.
/// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("missing required field: {0}")]
MissingField(&'static str),
}
/// A certificate chain together with its private key, both PEM-encoded.
#[derive(Clone)]
pub struct CertificateData {
/// PEM text of the certificate chain; `expiry_timestamp` treats the first
/// certificate as the leaf.
pub certificate_chain_pem: String,
/// PEM text of the private key, wrapped in `Zeroizing`. The manual `Debug`
/// implementation prints this field as `[REDACTED]`.
pub private_key_pem: Zeroizing<String>,
}
/// Hand-written `Debug` so the private key never reaches log output: the chain is
/// printed verbatim, the key is replaced by the literal `[REDACTED]`.
impl std::fmt::Debug for CertificateData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CertificateData");
        out.field("certificate_chain_pem", &self.certificate_chain_pem);
        out.field("private_key_pem", &"[REDACTED]");
        out.finish()
    }
}
impl CertificateData {
    /// Decodes every certificate of the PEM chain into DER, in file order.
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::Certificate`] when a PEM section cannot be parsed or when
    /// the text contains no certificate at all.
    pub fn certificate_chain_der(
        &self,
    ) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, TlsError> {
        let pem = self.certificate_chain_pem.as_bytes();
        let mut reader = std::io::BufReader::new(pem);
        let mut chain = Vec::new();
        // Stop at the first malformed section, exactly like a fallible `collect`.
        for entry in rustls_pemfile::certs(&mut reader) {
            match entry {
                Ok(cert) => chain.push(cert),
                Err(e) => {
                    return Err(TlsError::Certificate(format!(
                        "failed to parse PEM certificates: {e}"
                    )));
                }
            }
        }
        if chain.is_empty() {
            Err(TlsError::Certificate(
                "no certificates found in PEM data".to_owned(),
            ))
        } else {
            Ok(chain)
        }
    }

    /// Decodes the first private key found in the PEM text into DER.
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::Certificate`] when the PEM cannot be parsed or holds no key.
    pub fn private_key_der(&self) -> Result<rustls::pki_types::PrivateKeyDer<'static>, TlsError> {
        let pem = self.private_key_pem.as_bytes();
        let mut reader = std::io::BufReader::new(pem);
        let parsed = rustls_pemfile::private_key(&mut reader);
        match parsed {
            Ok(Some(key)) => Ok(key),
            Ok(None) => Err(TlsError::Certificate(
                "no private key found in PEM data".to_owned(),
            )),
            Err(e) => Err(TlsError::Certificate(format!(
                "failed to parse PEM private key: {e}"
            ))),
        }
    }

    /// Returns the `notAfter` instant of the leaf (first) certificate as a Unix
    /// timestamp in seconds.
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::Certificate`] when the chain or the leaf cannot be parsed.
    pub fn expiry_timestamp(&self) -> Result<i64, TlsError> {
        let chain = self.certificate_chain_der()?;
        let Some(leaf) = chain.first() else {
            return Err(TlsError::Certificate("empty certificate chain".to_owned()));
        };
        let (_rest, cert) = x509_parser::parse_x509_certificate(leaf.as_ref()).map_err(|e| {
            TlsError::Certificate(format!("failed to parse X.509 certificate: {e}"))
        })?;
        let not_after = cert.validity().not_after.timestamp();
        Ok(not_after)
    }

    /// Reports whether the leaf certificate expires within
    /// [`RENEWAL_THRESHOLD_DAYS`] days from now (already expired counts as `true`).
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::Certificate`] when the expiry cannot be determined or the
    /// current time cannot be read.
    pub fn needs_renewal(&self) -> Result<bool, TlsError> {
        let expires_at = self.expiry_timestamp()?;
        let current = match scp_core::time::now_secs() {
            Ok(secs) => secs.cast_signed(),
            Err(e) => return Err(TlsError::Certificate(format!("{e}"))),
        };
        let window_secs = RENEWAL_THRESHOLD_DAYS * 24 * 60 * 60;
        let remaining = expires_at - current;
        Ok(remaining < window_secs)
    }
}
/// Builds a TLS 1.3-only rustls server configuration that serves one fixed certificate.
///
/// Uses the `ring` crypto provider, no client authentication, and leaves ALPN unset.
///
/// # Errors
///
/// Returns [`TlsError::Certificate`] when the PEM material cannot be decoded and
/// [`TlsError::Config`] when rustls rejects the protocol versions or the certificate.
pub fn build_tls_server_config(
    cert_data: &CertificateData,
) -> Result<rustls::ServerConfig, TlsError> {
    let chain = cert_data.certificate_chain_der()?;
    let private_key = cert_data.private_key_der()?;

    let crypto = Arc::new(rustls::crypto::ring::default_provider());
    let versioned = rustls::ServerConfig::builder_with_provider(crypto)
        .with_protocol_versions(&[&rustls::version::TLS13])
        .map_err(|e| TlsError::Config(format!("failed to set TLS versions: {e}")))?;

    versioned
        .with_no_client_auth()
        .with_single_cert(chain, private_key)
        .map_err(|e| TlsError::Config(format!("failed to set certificate: {e}")))
}
/// Builds a TLS 1.3-only rustls server configuration whose certificate can be swapped
/// at runtime.
///
/// The returned [`CertResolver`] is the same one installed in the configuration, so
/// calling `update` on it changes the certificate served to new connections. ALPN is
/// set to `h2` followed by `http/1.1`.
///
/// # Errors
///
/// Returns [`TlsError::Certificate`] when the PEM material cannot be decoded and
/// [`TlsError::Config`] when the key type is unsupported or the protocol versions
/// are rejected.
pub fn build_reloadable_tls_config(
    cert_data: &CertificateData,
) -> Result<(rustls::ServerConfig, Arc<CertResolver>), TlsError> {
    let chain = cert_data.certificate_chain_der()?;
    let private_key = cert_data.private_key_der()?;
    let signer = rustls::crypto::ring::sign::any_supported_type(&private_key)
        .map_err(|e| TlsError::Config(format!("unsupported private key type: {e}")))?;

    let resolver = Arc::new(CertResolver::new(CertifiedKey::new(chain, signer)));
    // rustls wants a trait object; the concrete handle is returned to the caller.
    let installed: Arc<dyn ResolvesServerCert> = resolver.clone();

    let crypto = Arc::new(rustls::crypto::ring::default_provider());
    let versioned = rustls::ServerConfig::builder_with_provider(crypto)
        .with_protocol_versions(&[&rustls::version::TLS13])
        .map_err(|e| TlsError::Config(format!("failed to set TLS versions: {e}")))?;
    let mut config = versioned
        .with_no_client_auth()
        .with_cert_resolver(installed);
    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

    Ok((config, resolver))
}
/// Certificate resolver whose `CertifiedKey` can be replaced at runtime via `update`.
#[derive(Debug)]
pub struct CertResolver {
/// Key currently being served: `resolve` hands out clones of this `Arc`,
/// `update` replaces it under the write lock.
pub(crate) inner: std::sync::RwLock<Arc<CertifiedKey>>,
}
impl CertResolver {
    /// Creates a resolver that initially serves `key`.
    #[must_use]
    pub fn new(key: CertifiedKey) -> Self {
        Self {
            inner: std::sync::RwLock::new(Arc::new(key)),
        }
    }

    /// Replaces the served certificate; handshakes resolved afterwards see `key`.
    ///
    /// A poisoned lock is recovered rather than propagated: the stored value is a
    /// single `Arc` that is swapped in one assignment, so it cannot be observed in a
    /// half-written state.
    pub fn update(&self, key: CertifiedKey) {
        let replacement = Arc::new(key);
        let mut slot = self.inner.write().unwrap_or_else(|poisoned| {
            tracing::warn!("CertResolver RwLock was poisoned, clearing poison");
            // `into_inner` only hands back the guard; the poison flag stays set until
            // it is cleared explicitly, and a lock left poisoned would make every
            // later `read()` report an error.
            self.inner.clear_poison();
            poisoned.into_inner()
        });
        *slot = replacement;
    }
}
/// Serves the currently installed certificate to every client, regardless of the
/// contents of the ClientHello.
impl ResolvesServerCert for CertResolver {
    fn resolve(&self, _client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        // Returning `None` aborts the handshake, so a poisoned lock must not be turned
        // into `None`: the slot always holds a complete `Arc<CertifiedKey>` and can be
        // read safely even after a writer panicked.
        let current = self
            .inner
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        Some(Arc::clone(&*current))
    }
}
/// ACME client state for obtaining and renewing the certificate of a single domain.
pub struct AcmeProvider<S: Storage> {
/// DNS name placed in the ACME order as its only identifier.
domain: String,
/// Store used to load and persist the certificate / private key PEM pair.
storage: Arc<ProtocolStore<S>>,
/// Optional contact address; sent to the ACME server as `mailto:<email>`.
email: Option<String>,
/// ACME directory URL; `new` sets it to the Let's Encrypt production directory.
directory_url: String,
/// When set, the renewal loop installs renewed certificates into this resolver.
cert_resolver: Option<Arc<CertResolver>>,
/// HTTP-01 challenge responses keyed by token (value: key authorization).
/// Shared with the challenge router through `challenges()`.
challenges: Arc<RwLock<std::collections::HashMap<String, String>>>,
}
/// `Debug` output limited to the plain configuration values (domain, e-mail,
/// directory URL); the remaining fields are elided as non-exhaustive.
impl<S: Storage> std::fmt::Debug for AcmeProvider<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AcmeProvider");
        out.field("domain", &self.domain);
        out.field("email", &self.email);
        out.field("directory_url", &self.directory_url);
        out.finish_non_exhaustive()
    }
}
impl<S: Storage + 'static> AcmeProvider<S> {
/// Creates a provider for `domain` that persists certificates in `storage`.
///
/// Defaults: no contact e-mail, the Let's Encrypt production directory, no
/// certificate resolver, and an empty challenge map.
#[must_use]
pub fn new(domain: &str, storage: Arc<ProtocolStore<S>>) -> Self {
    let directory_url = String::from("https://acme-v02.api.letsencrypt.org/directory");
    let challenges = Arc::new(RwLock::new(std::collections::HashMap::new()));
    Self {
        domain: String::from(domain),
        storage,
        email: None,
        directory_url,
        cert_resolver: None,
        challenges,
    }
}
#[must_use]
pub fn with_email(mut self, email: &str) -> Self {
self.email = Some(email.to_owned());
self
}
/// Overrides the ACME directory URL (for example to target a staging endpoint).
#[must_use]
pub fn with_directory_url(mut self, url: &str) -> Self {
    // Overwrite in place so the existing string buffer is reused.
    self.directory_url.clear();
    self.directory_url.push_str(url);
    self
}
#[must_use]
pub fn with_cert_resolver(mut self, resolver: Arc<CertResolver>) -> Self {
self.cert_resolver = Some(resolver);
self
}
/// Returns another handle to the shared token -> key-authorization map, suitable
/// for passing to `acme_challenge_router`.
#[must_use]
pub fn challenges(&self) -> Arc<RwLock<std::collections::HashMap<String, String>>> {
    let shared = &self.challenges;
    Arc::clone(shared)
}
/// Reads the persisted certificate, if any, from the protocol store.
///
/// Returns `Ok(None)` when nothing has been stored yet and
/// [`TlsError::Storage`] when the store reports a failure.
async fn load_tls_cert(&self) -> Result<Option<CertificateData>, TlsError> {
    let stored = self
        .storage
        .load_tls_certificate()
        .await
        .map_err(|e| TlsError::Storage(format!("failed to load certificate: {e}")))?;
    Ok(stored.map(|(certificate_chain_pem, private_key_pem)| CertificateData {
        certificate_chain_pem,
        private_key_pem,
    }))
}
/// Obtains a fresh certificate for the configured domain via ACME HTTP-01 and
/// persists it in the protocol store.
///
/// While the order is in flight the challenge response is published in the shared
/// challenge map, which must be served by `acme_challenge_router` on the domain.
/// The map is cleared once the attempt ends, on success and on failure alike, so
/// failed attempts do not leave stale tokens behind.
///
/// # Errors
///
/// Returns [`TlsError::Acme`] when any ACME step fails and [`TlsError::Storage`]
/// when the issued certificate cannot be persisted.
pub async fn provision(&self) -> Result<CertificateData, TlsError> {
    let outcome = self.run_acme_order().await;
    // Clean up on both outcomes: a long-running renewal loop would otherwise
    // accumulate one stale token per failed attempt.
    self.challenges.write().await.clear();
    let cert_data = outcome?;
    tracing::info!(domain = %self.domain, "TLS certificate provisioned via ACME");
    Ok(cert_data)
}

/// Runs one complete ACME order (account, order, HTTP-01 challenge, finalize,
/// download) and stores the resulting certificate. Challenge-map cleanup is left to
/// the caller.
async fn run_acme_order(&self) -> Result<CertificateData, TlsError> {
    use instant_acme::{Account, Identifier, NewAccount, NewOrder};

    let contacts: Vec<String> = self
        .email
        .as_ref()
        .map(|e| vec![format!("mailto:{e}")])
        .unwrap_or_default();
    let contact_refs: Vec<&str> = contacts.iter().map(String::as_str).collect();
    let account_request = NewAccount {
        contact: &contact_refs,
        terms_of_service_agreed: true,
        only_return_existing: false,
    };
    let builder = Account::builder()
        .map_err(|e| TlsError::Acme(format!("failed to create account builder: {e}")))?;
    // NOTE(review): the account credentials are discarded, so every call registers
    // a new ACME account — confirm this is intended.
    let (account, _credentials) = builder
        .create(&account_request, self.directory_url.clone(), None)
        .await
        .map_err(|e| TlsError::Acme(format!("failed to create ACME account: {e}")))?;

    let identifier = Identifier::Dns(self.domain.clone());
    let identifiers = [identifier];
    let mut order = account
        .new_order(&NewOrder::new(&identifiers))
        .await
        .map_err(|e| TlsError::Acme(format!("failed to create order: {e}")))?;

    // Scoped so the borrow of `order` held by the authorization handles ends
    // before the order is polled.
    {
        let mut authorizations = order.authorizations();
        // Only the first authorization is processed; the order has one identifier.
        let mut auth = authorizations
            .next()
            .await
            .ok_or_else(|| TlsError::Acme("no authorizations returned".to_owned()))?
            .map_err(|e| TlsError::Acme(format!("authorization error: {e}")))?;
        let mut challenge_handle = auth
            .challenge(instant_acme::ChallengeType::Http01)
            .ok_or_else(|| TlsError::Acme("no HTTP-01 challenge found".to_owned()))?;
        let key_auth = challenge_handle.key_authorization().as_str().to_owned();
        let token = challenge_handle.token.clone();
        // Publish the response before telling the server the challenge is ready.
        {
            let mut map = self.challenges.write().await;
            map.insert(token.clone(), key_auth);
        }
        tracing::debug!(
            domain = %self.domain, %token,
            "ACME HTTP-01 challenge token stored in challenge map"
        );
        challenge_handle
            .set_ready()
            .await
            .map_err(|e| TlsError::Acme(format!("failed to set challenge ready: {e}")))?;
    }

    order
        .poll_ready(&instant_acme::RetryPolicy::default())
        .await
        .map_err(|e| TlsError::Acme(format!("order failed to become ready: {e}")))?;
    let private_key_pem = Zeroizing::new(
        order
            .finalize()
            .await
            .map_err(|e| TlsError::Acme(format!("failed to finalize order: {e}")))?,
    );
    let certificate_chain_pem = order
        .certificate()
        .await
        .map_err(|e| TlsError::Acme(format!("failed to download certificate: {e}")))?
        .ok_or_else(|| TlsError::Acme("no certificate returned".to_owned()))?;

    let cert_data = CertificateData {
        certificate_chain_pem,
        private_key_pem,
    };
    self.storage
        .store_tls_certificate(&cert_data.certificate_chain_pem, &cert_data.private_key_pem)
        .await
        .map_err(|e| TlsError::Storage(format!("failed to store certificate: {e}")))?;
    Ok(cert_data)
}
/// Returns the stored certificate when it is still outside the renewal window,
/// otherwise provisions a new one via ACME.
///
/// # Errors
///
/// Propagates storage errors, errors from the expiry check of the stored
/// certificate, and any error from [`Self::provision`].
pub async fn load_or_provision(&self) -> Result<CertificateData, TlsError> {
    let Some(stored) = self.load_tls_cert().await? else {
        // Nothing persisted yet.
        return self.provision().await;
    };
    if stored.needs_renewal()? {
        tracing::info!(domain = %self.domain, "existing certificate needs renewal");
        return self.provision().await;
    }
    tracing::info!(domain = %self.domain, "loaded existing TLS certificate from storage");
    Ok(stored)
}
#[must_use]
pub fn start_renewal_loop(self: Arc<Self>) -> tokio::task::JoinHandle<()>
where
S: Send + Sync + 'static,
{
tokio::spawn(async move {
loop {
tokio::time::sleep(RENEWAL_CHECK_INTERVAL).await;
match self.load_tls_cert().await {
Ok(Some(cert_data)) => match cert_data.needs_renewal() {
Ok(true) => {
tracing::info!(
domain = %self.domain,
"certificate approaching expiry, renewing"
);
match self.provision().await {
Ok(new_cert) => {
if let Some(resolver) = &self.cert_resolver
&& let Ok(certs) = new_cert.certificate_chain_der()
&& let Ok(key) = new_cert.private_key_der()
&& let Ok(signing_key) =
rustls::crypto::ring::sign::any_supported_type(&key)
{
let certified = CertifiedKey::new(certs, signing_key);
resolver.update(certified);
tracing::info!(
domain = %self.domain,
"TLS certificate renewed and hot-reloaded"
);
}
}
Err(e) => {
tracing::error!(
domain = %self.domain,
error = %e,
"failed to renew TLS certificate"
);
}
}
}
Ok(false) => {
tracing::debug!(
domain = %self.domain,
"certificate not yet due for renewal"
);
}
Err(e) => {
tracing::warn!(
domain = %self.domain,
error = %e,
"failed to check certificate expiry"
);
}
},
Ok(None) => {
tracing::warn!(
domain = %self.domain,
"no certificate in storage; skipping renewal check"
);
}
Err(e) => {
tracing::error!(
domain = %self.domain,
error = %e,
"failed to load certificate for renewal check"
);
}
}
}
})
}
}
/// Builds the router that answers ACME HTTP-01 validation requests.
///
/// `GET /.well-known/acme-challenge/{token}` replies `200` with the key authorization
/// stored under `token` in `challenges`, or `404` with an empty body when the token is
/// unknown. Both replies carry `Content-Type: text/plain`.
#[allow(clippy::implicit_hasher)]
pub fn acme_challenge_router(
    challenges: Arc<RwLock<std::collections::HashMap<String, String>>>,
) -> axum::Router {
    use axum::extract::{Path, State};
    use axum::http::StatusCode;
    use axum::response::IntoResponse;

    /// Looks the requested token up in the shared challenge map.
    async fn handle_challenge(
        State(challenges): State<Arc<RwLock<std::collections::HashMap<String, String>>>>,
        Path(token): Path<String>,
    ) -> impl IntoResponse {
        let pending = challenges.read().await;
        let (status, body) = match pending.get(&token) {
            Some(key_auth) => (StatusCode::OK, key_auth.clone()),
            None => (StatusCode::NOT_FOUND, String::new()),
        };
        (
            status,
            [(axum::http::header::CONTENT_TYPE, "text/plain")],
            body,
        )
    }

    axum::Router::new()
        .route(
            "/.well-known/acme-challenge/{token}",
            axum::routing::get(handle_challenge),
        )
        .with_state(challenges)
}
/// Accepts TCP connections on `listener`, performs the TLS handshake and serves `app`
/// over HTTP/1.1 or HTTP/2 until `shutdown_token` is cancelled.
///
/// Each connection runs in its own task: the handshake is limited to 10 seconds, the
/// peer address is exposed to handlers as `ConnectInfo`, and HTTP/2 is capped at 100
/// concurrent streams. After cancellation no new connections are accepted and the
/// function waits up to 30 seconds for the running connection tasks to finish.
///
/// # Errors
///
/// Accept errors are logged and skipped; the visible code path always returns `Ok(())`.
pub async fn serve_tls(
    listener: tokio::net::TcpListener,
    tls_config: Arc<rustls::ServerConfig>,
    app: axum::Router,
    shutdown_token: tokio_util::sync::CancellationToken,
) -> Result<(), crate::NodeError> {
    use axum::extract::Request;
    use hyper::body::Incoming;
    use hyper_util::rt::{TokioExecutor, TokioIo};
    use tower_service::Service;

    let tls_acceptor = tokio_rustls::TlsAcceptor::from(tls_config);
    let connection_tracker = Arc::new(tokio::sync::Notify::new());
    let active_connections = Arc::new(std::sync::atomic::AtomicUsize::new(0));

    loop {
        let (tcp_stream, peer_addr) = tokio::select! {
            // Poll shutdown first so a cancelled token wins over a ready accept.
            biased;
            () = shutdown_token.cancelled() => {
                tracing::info!("TLS server shutting down, draining in-flight connections");
                drain_active_connections(&active_connections, &connection_tracker).await;
                return Ok(());
            }
            result = listener.accept() => {
                match result {
                    Ok(pair) => pair,
                    Err(e) => {
                        tracing::warn!(error = %e, "TCP accept error");
                        continue;
                    }
                }
            }
        };

        let tls_acceptor = tls_acceptor.clone();
        let tower_service = app.clone();
        // Counted before the task is spawned so a shutdown arriving right after the
        // accept already sees this connection.
        let connection = ActiveConnectionGuard::register(&active_connections, &connection_tracker);

        tokio::spawn(async move {
            // Declared first, therefore dropped last: the counter is decremented only
            // after the connection itself is gone, on every exit path including a
            // panic inside the task.
            let _connection = connection;

            let tls_stream = match tokio::time::timeout(
                Duration::from_secs(10),
                tls_acceptor.accept(tcp_stream),
            )
            .await
            {
                Ok(Ok(stream)) => stream,
                Ok(Err(e)) => {
                    tracing::debug!(
                        peer = %peer_addr,
                        error = %e,
                        "TLS handshake failed"
                    );
                    return;
                }
                Err(_elapsed) => {
                    tracing::debug!(
                        peer = %peer_addr,
                        "TLS handshake timed out (10s)"
                    );
                    return;
                }
            };

            let io = TokioIo::new(tls_stream);
            let hyper_service = hyper::service::service_fn(move |mut req: Request<Incoming>| {
                req.extensions_mut()
                    .insert(axum::extract::ConnectInfo(peer_addr));
                tower_service.clone().call(req)
            });
            let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
            builder.http2().max_concurrent_streams(100);
            let result = builder
                .serve_connection_with_upgrades(io, hyper_service)
                .await;
            if let Err(e) = result {
                tracing::debug!(
                    peer = %peer_addr,
                    error = %e,
                    "connection error"
                );
            }
        });
    }
}

/// Keeps one connection counted in the shared counter for as long as the guard lives
/// and wakes the drain loop when it is dropped.
struct ActiveConnectionGuard {
    active: Arc<std::sync::atomic::AtomicUsize>,
    tracker: Arc<tokio::sync::Notify>,
}

impl ActiveConnectionGuard {
    /// Increments the counter and returns the guard that will decrement it again.
    fn register(
        active: &Arc<std::sync::atomic::AtomicUsize>,
        tracker: &Arc<tokio::sync::Notify>,
    ) -> Self {
        active.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        Self {
            active: Arc::clone(active),
            tracker: Arc::clone(tracker),
        }
    }
}

impl Drop for ActiveConnectionGuard {
    fn drop(&mut self) {
        self.active
            .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
        self.tracker.notify_waiters();
    }
}

/// Waits until `active` reaches zero or 30 seconds have passed, whichever is first.
async fn drain_active_connections(
    active: &std::sync::atomic::AtomicUsize,
    tracker: &tokio::sync::Notify,
) {
    let drain_timeout = Duration::from_secs(30);
    let drain_start = tokio::time::Instant::now();
    loop {
        // Create the `Notified` future BEFORE reading the counter: it receives
        // `notify_waiters()` wakeups from the moment it exists, so a connection that
        // finishes between the check and the await below cannot be missed.
        let notified = tracker.notified();
        let count = active.load(std::sync::atomic::Ordering::SeqCst);
        if count == 0 {
            tracing::info!("all connections drained");
            return;
        }
        let elapsed = drain_start.elapsed();
        if elapsed >= drain_timeout {
            tracing::warn!(
                remaining = count,
                "drain timeout reached (30s), {count} connections still active"
            );
            return;
        }
        // A timeout here is not an error; the loop re-checks counter and deadline.
        let _ = tokio::time::timeout(drain_timeout.saturating_sub(elapsed), notified).await;
    }
}
/// Generates a self-signed certificate and a new key pair for `domain`.
///
/// `domain` is used both as the subject alternative name and as the common name of
/// the otherwise empty distinguished name.
///
/// # Errors
///
/// Returns [`TlsError::Certificate`] when the parameters are rejected, key generation
/// fails, or signing fails.
pub fn generate_self_signed(domain: &str) -> Result<CertificateData, TlsError> {
    let mut subject = rcgen::DistinguishedName::new();
    subject.push(rcgen::DnType::CommonName, domain);

    let mut params = rcgen::CertificateParams::new(vec![domain.to_owned()])
        .map_err(|e| TlsError::Certificate(format!("failed to create cert params: {e}")))?;
    params.distinguished_name = subject;

    let key_pair = rcgen::KeyPair::generate()
        .map_err(|e| TlsError::Certificate(format!("failed to generate key pair: {e}")))?;
    let cert = params
        .self_signed(&key_pair)
        .map_err(|e| TlsError::Certificate(format!("failed to generate self-signed cert: {e}")))?;

    let certificate_chain_pem = cert.pem();
    let private_key_pem = Zeroizing::new(key_pair.serialize_pem());
    Ok(CertificateData {
        certificate_chain_pem,
        private_key_pem,
    })
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::similar_names,
    clippy::cast_possible_wrap,
    clippy::significant_drop_tightening
)]
mod tests {
    use super::*;
    use scp_platform::testing::InMemoryStorage;

    /// Converts PEM material into the `CertifiedKey` form consumed by `CertResolver`.
    fn certified_key_from(cert: &CertificateData) -> CertifiedKey {
        let chain = cert.certificate_chain_der().unwrap();
        let key = cert.private_key_der().unwrap();
        let signing = rustls::crypto::ring::sign::any_supported_type(&key).unwrap();
        CertifiedKey::new(chain, signing)
    }

    /// Sends `GET /.well-known/acme-challenge/{token}` through `router`.
    async fn get_challenge(router: axum::Router, token: &str) -> axum::response::Response {
        use axum::body::Body;
        use tower::ServiceExt;
        let request = axum::http::Request::builder()
            .uri(format!("/.well-known/acme-challenge/{token}"))
            .body(Body::empty())
            .unwrap();
        router.oneshot(request).await.unwrap()
    }

    /// Asserts that `response` is a `200 text/plain` reply whose body is `expected`.
    async fn assert_key_authorization(response: axum::response::Response, expected: &[u8]) {
        use http_body_util::BodyExt;
        assert_eq!(response.status(), axum::http::StatusCode::OK);
        let content_type = response
            .headers()
            .get("content-type")
            .expect("should have Content-Type header")
            .to_str()
            .unwrap();
        assert_eq!(content_type, "text/plain");
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], expected);
    }

    #[test]
    fn generate_self_signed_produces_valid_pem() {
        let cert = generate_self_signed("test.example.com").unwrap();
        assert!(cert.certificate_chain_pem.contains("BEGIN CERTIFICATE"));
        assert!(cert.private_key_pem.contains("BEGIN PRIVATE KEY"));
    }

    #[test]
    fn certificate_chain_der_parses_pem() {
        let cert = generate_self_signed("test.example.com").unwrap();
        let der_certs = cert.certificate_chain_der().unwrap();
        assert_eq!(
            der_certs.len(),
            1,
            "self-signed should have exactly one cert"
        );
    }

    #[test]
    fn private_key_der_parses_pem() {
        let cert = generate_self_signed("test.example.com").unwrap();
        let _key = cert.private_key_der().unwrap();
    }

    #[test]
    fn expiry_timestamp_is_in_future() {
        let cert = generate_self_signed("test.example.com").unwrap();
        let expiry = cert.expiry_timestamp().unwrap();
        let now = scp_core::time::now_secs().expect("clock unavailable in test") as i64;
        assert!(expiry > now, "self-signed cert should expire in the future");
    }

    #[test]
    fn fresh_self_signed_does_not_need_renewal() {
        let cert = generate_self_signed("test.example.com").unwrap();
        assert!(
            !cert.needs_renewal().unwrap(),
            "a freshly generated cert should not need renewal"
        );
    }

    #[test]
    fn build_tls_server_config_enforces_tls_13() {
        let cert = generate_self_signed("test.example.com").unwrap();
        let config = build_tls_server_config(&cert).unwrap();
        assert!(
            config.alpn_protocols.is_empty(),
            "basic config should not set ALPN"
        );
        let _acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
    }

    #[test]
    fn build_reloadable_tls_config_returns_resolver() {
        let cert = generate_self_signed("test.example.com").unwrap();
        let (config, resolver) = build_reloadable_tls_config(&cert).unwrap();
        assert!(
            !config.alpn_protocols.is_empty(),
            "reloadable config should set ALPN"
        );
        let guard = resolver.inner.try_read().unwrap();
        assert!(!guard.cert.is_empty(), "resolver should have certificates");
    }

    /// Compares the stored DER chains rather than their lengths: both chains hold
    /// exactly one certificate, so a length check would pass even if `update` did
    /// nothing.
    #[test]
    fn cert_resolver_update_swaps_certificate() {
        let cert1 = generate_self_signed("one.example.com").unwrap();
        let cert2 = generate_self_signed("two.example.com").unwrap();
        let chain1 = cert1.certificate_chain_der().unwrap();
        let chain2 = cert2.certificate_chain_der().unwrap();
        assert_ne!(chain1, chain2, "the two test certificates must differ");

        let resolver = CertResolver::new(certified_key_from(&cert1));
        {
            let guard = resolver.inner.read().unwrap();
            assert_eq!(guard.cert, chain1);
        }
        resolver.update(certified_key_from(&cert2));
        {
            let guard = resolver.inner.read().unwrap();
            assert_eq!(guard.cert, chain2);
        }
    }

    #[tokio::test]
    async fn certificate_storage_roundtrip() {
        let store = ProtocolStore::new_for_testing(InMemoryStorage::new());
        let original = generate_self_signed("roundtrip.example.com").unwrap();
        store
            .store_tls_certificate(&original.certificate_chain_pem, &original.private_key_pem)
            .await
            .unwrap();
        let (cert, key) = store.load_tls_certificate().await.unwrap().unwrap();
        assert_eq!(cert, original.certificate_chain_pem);
        assert_eq!(key, original.private_key_pem);
    }

    #[tokio::test]
    async fn load_certificate_returns_none_when_empty() {
        let store = ProtocolStore::new_for_testing(InMemoryStorage::new());
        let result = store.load_tls_certificate().await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn acme_challenge_router_serves_token() {
        let challenges = Arc::new(RwLock::new(std::collections::HashMap::new()));
        {
            let mut map = challenges.write().await;
            map.insert("test-token".to_owned(), "test-key-auth".to_owned());
        }
        let router = acme_challenge_router(challenges);
        let response = get_challenge(router, "test-token").await;
        assert_key_authorization(response, b"test-key-auth").await;
    }

    #[tokio::test]
    async fn acme_challenge_router_returns_404_for_unknown_token() {
        let challenges = Arc::new(RwLock::new(std::collections::HashMap::new()));
        let router = acme_challenge_router(challenges);
        let response = get_challenge(router, "unknown").await;
        assert_eq!(response.status(), axum::http::StatusCode::NOT_FOUND);
    }

    #[test]
    fn acme_provider_new_sets_defaults() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("example.com", storage);
        assert_eq!(provider.domain, "example.com");
        assert!(provider.email.is_none());
        assert!(provider.directory_url.contains("letsencrypt"));
    }

    #[test]
    fn acme_provider_with_email() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("example.com", storage).with_email("admin@example.com");
        assert_eq!(provider.email.as_deref(), Some("admin@example.com"));
    }

    #[test]
    fn acme_provider_with_directory_url() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("example.com", storage)
            .with_directory_url("https://acme-staging-v02.api.letsencrypt.org/directory");
        assert!(provider.directory_url.contains("staging"));
    }

    #[test]
    fn acme_provider_with_cert_resolver() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let cert = generate_self_signed("example.com").unwrap();
        let resolver = Arc::new(CertResolver::new(certified_key_from(&cert)));
        let provider =
            AcmeProvider::new("example.com", storage).with_cert_resolver(Arc::clone(&resolver));
        assert!(provider.cert_resolver.is_some());
    }

    #[test]
    fn acme_provider_challenges_returns_shared_map() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("example.com", storage);
        let challenges_a = provider.challenges();
        let challenges_b = provider.challenges();
        assert!(Arc::ptr_eq(&challenges_a, &challenges_b));
    }

    #[tokio::test]
    async fn acme_challenge_router_serves_from_shared_map() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("example.com", storage);
        let challenges = provider.challenges();
        {
            let mut map = challenges.write().await;
            map.insert("acme-token-abc".to_owned(), "key-auth-xyz".to_owned());
        }
        let router = acme_challenge_router(Arc::clone(&challenges));
        let response = get_challenge(router, "acme-token-abc").await;
        assert_key_authorization(response, b"key-auth-xyz").await;
    }

    /// Port 1 on localhost is expected to refuse the connection, so `provision` must
    /// fail quickly instead of hanging.
    #[tokio::test]
    async fn provision_without_acme_server_returns_error() {
        let _ = rustls::crypto::ring::default_provider().install_default();
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("test.example.com", storage)
            .with_directory_url("http://127.0.0.1:1/nonexistent");
        let result =
            tokio::time::timeout(std::time::Duration::from_secs(10), provider.provision()).await;
        let provision_result = result.expect("provision() should not hang");
        assert!(
            provision_result.is_err(),
            "provision() without ACME server should return TlsError"
        );
    }

    /// Simulates the challenge-map lifecycle without an ACME server: publish a token,
    /// serve it, reject an unknown token, then clear the map and expect a 404.
    #[tokio::test]
    async fn acme_challenge_pipeline_end_to_end() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("test.example.com", storage);
        let challenges = provider.challenges();
        {
            let mut map = challenges.write().await;
            map.insert(
                "simulated-token".to_owned(),
                "simulated-key-auth".to_owned(),
            );
        }
        let router = acme_challenge_router(provider.challenges());

        let response = get_challenge(router.clone(), "simulated-token").await;
        assert_key_authorization(response, b"simulated-key-auth").await;

        let response_404 = get_challenge(router, "unknown-token").await;
        assert_eq!(response_404.status(), axum::http::StatusCode::NOT_FOUND);

        {
            let mut map = challenges.write().await;
            map.clear();
        }
        let router_after_clear = acme_challenge_router(provider.challenges());
        let response_cleared = get_challenge(router_after_clear, "simulated-token").await;
        assert_eq!(
            response_cleared.status(),
            axum::http::StatusCode::NOT_FOUND,
            "cleared challenge map should no longer serve token"
        );
    }
}