1use std::sync::Arc;
18use std::time::Duration;
19
20use rustls::server::ResolvesServerCert;
21use rustls::sign::CertifiedKey;
22use scp_core::store::ProtocolStore;
23use scp_platform::traits::Storage;
24use tokio::sync::RwLock;
25use zeroize::Zeroizing;
26
/// Remaining-validity threshold, in days. A certificate whose leaf expires in
/// fewer than this many days is reported as due by
/// [`CertificateData::needs_renewal`].
const RENEWAL_THRESHOLD_DAYS: i64 = 30;
33
/// How long the background renewal task sleeps between expiry checks
/// (12 hours); see [`AcmeProvider::start_renewal_loop`].
const RENEWAL_CHECK_INTERVAL: Duration = Duration::from_secs(12 * 60 * 60);

/// Errors produced by certificate parsing, ACME provisioning, certificate
/// storage and TLS configuration in this module.
///
/// Every string-carrying variant holds an already formatted, human-readable
/// description of the underlying failure.
#[derive(Debug, thiserror::Error)]
pub enum TlsError {
    /// An ACME step failed (account creation, order, challenge, finalize or
    /// certificate download).
    #[error("ACME error: {0}")]
    Acme(String),

    /// Certificate or key material could not be parsed or generated. Also used
    /// by `needs_renewal` when the current time cannot be read.
    #[error("certificate error: {0}")]
    Certificate(String),

    /// Loading from or writing to the certificate store failed.
    #[error("storage error: {0}")]
    Storage(String),

    /// A `rustls` server configuration could not be built from the supplied
    /// material (protocol versions, certificate, or key type).
    #[error("TLS config error: {0}")]
    Config(String),

    /// A required field was not supplied; the payload names the field.
    // NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("missing required field: {0}")]
    MissingField(&'static str),
}
64
/// A PEM-encoded certificate chain together with its PEM-encoded private key.
///
/// `Debug` is implemented by hand so the private key is never printed.
#[derive(Clone)]
pub struct CertificateData {
    /// PEM text of the certificate chain; the first certificate is treated as
    /// the leaf when computing expiry.
    pub certificate_chain_pem: String,
    /// PEM text of the private key, wrapped in [`Zeroizing`].
    pub private_key_pem: Zeroizing<String>,
}
86
87impl std::fmt::Debug for CertificateData {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("CertificateData")
90 .field("certificate_chain_pem", &self.certificate_chain_pem)
91 .field("private_key_pem", &"[REDACTED]")
92 .finish()
93 }
94}
95
96impl CertificateData {
97 pub fn certificate_chain_der(
103 &self,
104 ) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, TlsError> {
105 let mut reader = std::io::BufReader::new(self.certificate_chain_pem.as_bytes());
106 let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
107 .collect::<Result<Vec<_>, _>>()
108 .map_err(|e| TlsError::Certificate(format!("failed to parse PEM certificates: {e}")))?;
109
110 if certs.is_empty() {
111 return Err(TlsError::Certificate(
112 "no certificates found in PEM data".to_owned(),
113 ));
114 }
115
116 Ok(certs)
117 }
118
119 pub fn private_key_der(&self) -> Result<rustls::pki_types::PrivateKeyDer<'static>, TlsError> {
125 let mut reader = std::io::BufReader::new(self.private_key_pem.as_bytes());
126 rustls_pemfile::private_key(&mut reader)
127 .map_err(|e| TlsError::Certificate(format!("failed to parse PEM private key: {e}")))?
128 .ok_or_else(|| TlsError::Certificate("no private key found in PEM data".to_owned()))
129 }
130
131 pub fn expiry_timestamp(&self) -> Result<i64, TlsError> {
140 let certs = self.certificate_chain_der()?;
141 let leaf = certs
142 .first()
143 .ok_or_else(|| TlsError::Certificate("empty certificate chain".to_owned()))?;
144
145 let (_, cert) = x509_parser::parse_x509_certificate(leaf.as_ref()).map_err(|e| {
146 TlsError::Certificate(format!("failed to parse X.509 certificate: {e}"))
147 })?;
148
149 Ok(cert.validity().not_after.timestamp())
150 }
151
152 pub fn needs_renewal(&self) -> Result<bool, TlsError> {
158 let expiry = self.expiry_timestamp()?;
159 let now = scp_core::time::now_secs()
160 .map_err(|e| TlsError::Certificate(format!("{e}")))?
161 .cast_signed();
162
163 let threshold = RENEWAL_THRESHOLD_DAYS * 24 * 60 * 60;
164 Ok(expiry - now < threshold)
165 }
166}
167
168pub fn build_tls_server_config(
181 cert_data: &CertificateData,
182) -> Result<rustls::ServerConfig, TlsError> {
183 let certs = cert_data.certificate_chain_der()?;
184 let key = cert_data.private_key_der()?;
185
186 let provider = Arc::new(rustls::crypto::ring::default_provider());
187 let config = rustls::ServerConfig::builder_with_provider(provider)
188 .with_protocol_versions(&[&rustls::version::TLS13])
189 .map_err(|e| TlsError::Config(format!("failed to set TLS versions: {e}")))?
190 .with_no_client_auth()
191 .with_single_cert(certs, key)
192 .map_err(|e| TlsError::Config(format!("failed to set certificate: {e}")))?;
193
194 Ok(config)
195}
196
197pub fn build_reloadable_tls_config(
207 cert_data: &CertificateData,
208) -> Result<(rustls::ServerConfig, Arc<CertResolver>), TlsError> {
209 let certs = cert_data.certificate_chain_der()?;
210 let key = cert_data.private_key_der()?;
211
212 let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)
213 .map_err(|e| TlsError::Config(format!("unsupported private key type: {e}")))?;
214
215 let certified_key = CertifiedKey::new(certs, signing_key);
216 let resolver = Arc::new(CertResolver::new(certified_key));
217
218 let provider = Arc::new(rustls::crypto::ring::default_provider());
219 let mut config = rustls::ServerConfig::builder_with_provider(provider)
220 .with_protocol_versions(&[&rustls::version::TLS13])
221 .map_err(|e| TlsError::Config(format!("failed to set TLS versions: {e}")))?
222 .with_no_client_auth()
223 .with_cert_resolver(resolver.clone() as Arc<dyn ResolvesServerCert>);
224
225 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
226
227 Ok((config, resolver))
228}
229
/// Certificate resolver that always answers with the currently installed
/// [`CertifiedKey`], which can be replaced at runtime through `update`.
#[derive(Debug)]
pub struct CertResolver {
    /// The certificate/key pair currently handed out by `resolve`; guarded by
    /// a synchronous `std::sync::RwLock`.
    pub(crate) inner: std::sync::RwLock<Arc<CertifiedKey>>,
}
250
251impl CertResolver {
252 #[must_use]
254 pub fn new(key: CertifiedKey) -> Self {
255 Self {
256 inner: std::sync::RwLock::new(Arc::new(key)),
257 }
258 }
259
260 pub fn update(&self, key: CertifiedKey) {
270 let mut guard = match self.inner.write() {
271 Ok(g) => g,
272 Err(poisoned) => {
273 tracing::warn!("CertResolver RwLock was poisoned, clearing poison");
274 poisoned.into_inner()
275 }
276 };
277 *guard = Arc::new(key);
278 }
279}
280
281impl ResolvesServerCert for CertResolver {
282 fn resolve(&self, _client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
283 self.inner.read().ok().map(|guard| Arc::clone(&*guard))
284 }
285}
286
/// Obtains and renews a TLS certificate for one domain through ACME HTTP-01,
/// persisting the result in a [`ProtocolStore`].
pub struct AcmeProvider<S: Storage> {
    /// DNS name used as the ACME order identifier.
    domain: String,
    /// Store the certificate and key are loaded from and written to.
    storage: Arc<ProtocolStore<S>>,
    /// Optional contact address, sent to the ACME server as a `mailto:` URI.
    email: Option<String>,
    /// ACME directory URL; `new` sets the Let's Encrypt v2 directory.
    directory_url: String,
    /// Resolver updated with the new certificate after a renewal in the
    /// background loop, if one was attached.
    cert_resolver: Option<Arc<CertResolver>>,
    /// Pending HTTP-01 challenges, keyed by token with the key authorization
    /// as value; handed out via `challenges()` so a router can serve them.
    challenges: Arc<RwLock<std::collections::HashMap<String, String>>>,
}
325
326impl<S: Storage> std::fmt::Debug for AcmeProvider<S> {
327 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328 f.debug_struct("AcmeProvider")
329 .field("domain", &self.domain)
330 .field("email", &self.email)
331 .field("directory_url", &self.directory_url)
332 .finish_non_exhaustive()
333 }
334}
335
336impl<S: Storage + 'static> AcmeProvider<S> {
337 #[must_use]
342 pub fn new(domain: &str, storage: Arc<ProtocolStore<S>>) -> Self {
343 Self {
344 domain: domain.to_owned(),
345 storage,
346 email: None,
347 directory_url: "https://acme-v02.api.letsencrypt.org/directory".to_owned(),
348 cert_resolver: None,
349 challenges: Arc::new(RwLock::new(std::collections::HashMap::new())),
350 }
351 }
352
353 #[must_use]
355 pub fn with_email(mut self, email: &str) -> Self {
356 self.email = Some(email.to_owned());
357 self
358 }
359
360 #[must_use]
362 pub fn with_directory_url(mut self, url: &str) -> Self {
363 url.clone_into(&mut self.directory_url);
364 self
365 }
366
367 #[must_use]
369 pub fn with_cert_resolver(mut self, resolver: Arc<CertResolver>) -> Self {
370 self.cert_resolver = Some(resolver);
371 self
372 }
373
374 #[must_use]
380 pub fn challenges(&self) -> Arc<RwLock<std::collections::HashMap<String, String>>> {
381 Arc::clone(&self.challenges)
382 }
383
384 async fn load_tls_cert(&self) -> Result<Option<CertificateData>, TlsError> {
387 match self
388 .storage
389 .load_tls_certificate()
390 .await
391 .map_err(|e| TlsError::Storage(format!("failed to load certificate: {e}")))?
392 {
393 Some((certificate_chain_pem, private_key_pem)) => Ok(Some(CertificateData {
394 certificate_chain_pem,
395 private_key_pem,
396 })),
397 None => Ok(None),
398 }
399 }
400
401 pub async fn provision(&self) -> Result<CertificateData, TlsError> {
418 use instant_acme::{Account, Identifier, NewAccount, NewOrder};
419
420 let contacts: Vec<String> = self
422 .email
423 .as_ref()
424 .map(|e| vec![format!("mailto:{e}")])
425 .unwrap_or_default();
426
427 let contact_refs: Vec<&str> = contacts.iter().map(String::as_str).collect();
428
429 let account_request = NewAccount {
430 contact: &contact_refs,
431 terms_of_service_agreed: true,
432 only_return_existing: false,
433 };
434
435 let builder = Account::builder()
436 .map_err(|e| TlsError::Acme(format!("failed to create account builder: {e}")))?;
437
438 let (account, _credentials) = builder
439 .create(&account_request, self.directory_url.clone(), None)
440 .await
441 .map_err(|e| TlsError::Acme(format!("failed to create ACME account: {e}")))?;
442
443 let identifier = Identifier::Dns(self.domain.clone());
445 let identifiers = [identifier];
446 let mut order = account
447 .new_order(&NewOrder::new(&identifiers))
448 .await
449 .map_err(|e| TlsError::Acme(format!("failed to create order: {e}")))?;
450
451 {
455 let mut authorizations = order.authorizations();
456
457 let mut auth = authorizations
458 .next()
459 .await
460 .ok_or_else(|| TlsError::Acme("no authorizations returned".to_owned()))?
461 .map_err(|e| TlsError::Acme(format!("authorization error: {e}")))?;
462
463 let mut challenge_handle = auth
464 .challenge(instant_acme::ChallengeType::Http01)
465 .ok_or_else(|| TlsError::Acme("no HTTP-01 challenge found".to_owned()))?;
466
467 let key_auth = challenge_handle.key_authorization().as_str().to_owned();
468 let token = challenge_handle.token.clone();
469
470 {
474 let mut map = self.challenges.write().await;
475 map.insert(token.clone(), key_auth);
476 }
477
478 tracing::debug!(
479 domain = %self.domain, %token,
480 "ACME HTTP-01 challenge token stored in challenge map"
481 );
482
483 challenge_handle
485 .set_ready()
486 .await
487 .map_err(|e| TlsError::Acme(format!("failed to set challenge ready: {e}")))?;
488 }
489
490 order
492 .poll_ready(&instant_acme::RetryPolicy::default())
493 .await
494 .map_err(|e| TlsError::Acme(format!("order failed to become ready: {e}")))?;
495
496 let private_key_pem = Zeroizing::new(
498 order
499 .finalize()
500 .await
501 .map_err(|e| TlsError::Acme(format!("failed to finalize order: {e}")))?,
502 );
503
504 let certificate_chain_pem = order
506 .certificate()
507 .await
508 .map_err(|e| TlsError::Acme(format!("failed to download certificate: {e}")))?
509 .ok_or_else(|| TlsError::Acme("no certificate returned".to_owned()))?;
510
511 let cert_data = CertificateData {
512 certificate_chain_pem,
513 private_key_pem,
514 };
515
516 self.storage
518 .store_tls_certificate(&cert_data.certificate_chain_pem, &cert_data.private_key_pem)
519 .await
520 .map_err(|e| TlsError::Storage(format!("failed to store certificate: {e}")))?;
521
522 {
526 let mut map = self.challenges.write().await;
527 map.clear();
528 }
529
530 tracing::info!(domain = %self.domain, "TLS certificate provisioned via ACME");
531
532 Ok(cert_data)
533 }
534
535 pub async fn load_or_provision(&self) -> Result<CertificateData, TlsError> {
542 if let Some(cert_data) = self.load_tls_cert().await? {
543 if !cert_data.needs_renewal()? {
544 tracing::info!(domain = %self.domain, "loaded existing TLS certificate from storage");
545 return Ok(cert_data);
546 }
547 tracing::info!(domain = %self.domain, "existing certificate needs renewal");
548 }
549
550 self.provision().await
551 }
552
553 #[must_use]
559 pub fn start_renewal_loop(self: Arc<Self>) -> tokio::task::JoinHandle<()>
560 where
561 S: Send + Sync + 'static,
562 {
563 tokio::spawn(async move {
564 loop {
565 tokio::time::sleep(RENEWAL_CHECK_INTERVAL).await;
566
567 match self.load_tls_cert().await {
568 Ok(Some(cert_data)) => match cert_data.needs_renewal() {
569 Ok(true) => {
570 tracing::info!(
571 domain = %self.domain,
572 "certificate approaching expiry, renewing"
573 );
574 match self.provision().await {
575 Ok(new_cert) => {
576 if let Some(resolver) = &self.cert_resolver
578 && let Ok(certs) = new_cert.certificate_chain_der()
579 && let Ok(key) = new_cert.private_key_der()
580 && let Ok(signing_key) =
581 rustls::crypto::ring::sign::any_supported_type(&key)
582 {
583 let certified = CertifiedKey::new(certs, signing_key);
584 resolver.update(certified);
585 tracing::info!(
586 domain = %self.domain,
587 "TLS certificate renewed and hot-reloaded"
588 );
589 }
590 }
591 Err(e) => {
592 tracing::error!(
593 domain = %self.domain,
594 error = %e,
595 "failed to renew TLS certificate"
596 );
597 }
598 }
599 }
600 Ok(false) => {
601 tracing::debug!(
602 domain = %self.domain,
603 "certificate not yet due for renewal"
604 );
605 }
606 Err(e) => {
607 tracing::warn!(
608 domain = %self.domain,
609 error = %e,
610 "failed to check certificate expiry"
611 );
612 }
613 },
614 Ok(None) => {
615 tracing::warn!(
616 domain = %self.domain,
617 "no certificate in storage; skipping renewal check"
618 );
619 }
620 Err(e) => {
621 tracing::error!(
622 domain = %self.domain,
623 error = %e,
624 "failed to load certificate for renewal check"
625 );
626 }
627 }
628 }
629 })
630 }
631}
632
633#[allow(clippy::implicit_hasher)]
647pub fn acme_challenge_router(
648 challenges: Arc<RwLock<std::collections::HashMap<String, String>>>,
649) -> axum::Router {
650 use axum::extract::{Path, State};
651 use axum::http::StatusCode;
652 use axum::response::IntoResponse;
653
654 async fn handle_challenge(
655 State(challenges): State<Arc<RwLock<std::collections::HashMap<String, String>>>>,
656 Path(token): Path<String>,
657 ) -> impl IntoResponse {
658 let map = challenges.read().await;
659 map.get(&token).map_or_else(
660 || {
661 (
662 StatusCode::NOT_FOUND,
663 [(axum::http::header::CONTENT_TYPE, "text/plain")],
664 String::new(),
665 )
666 },
667 |key_auth| {
668 (
669 StatusCode::OK,
670 [(axum::http::header::CONTENT_TYPE, "text/plain")],
671 key_auth.clone(),
672 )
673 },
674 )
675 }
676
677 axum::Router::new()
678 .route(
679 "/.well-known/acme-challenge/{token}",
680 axum::routing::get(handle_challenge),
681 )
682 .with_state(challenges)
683}
684
685pub async fn serve_tls(
708 listener: tokio::net::TcpListener,
709 tls_config: Arc<rustls::ServerConfig>,
710 app: axum::Router,
711 shutdown_token: tokio_util::sync::CancellationToken,
712) -> Result<(), crate::NodeError> {
713 use axum::extract::Request;
714 use hyper::body::Incoming;
715 use hyper_util::rt::{TokioExecutor, TokioIo};
716 use tower_service::Service;
717
718 let tls_acceptor = tokio_rustls::TlsAcceptor::from(tls_config);
719
720 let connection_tracker = Arc::new(tokio::sync::Notify::new());
723 let active_connections = Arc::new(std::sync::atomic::AtomicUsize::new(0));
724
725 loop {
726 let (tcp_stream, peer_addr) = tokio::select! {
728 biased;
729 () = shutdown_token.cancelled() => {
730 tracing::info!("TLS server shutting down, draining in-flight connections");
731 let drain_start = tokio::time::Instant::now();
734 let drain_timeout = Duration::from_secs(30);
735 loop {
736 let count = active_connections.load(std::sync::atomic::Ordering::Relaxed);
737 if count == 0 {
738 tracing::info!("all connections drained");
739 break;
740 }
741 if drain_start.elapsed() >= drain_timeout {
742 tracing::warn!(
743 remaining = count,
744 "drain timeout reached (30s), {count} connections still active"
745 );
746 break;
747 }
748 let remaining = drain_timeout.saturating_sub(drain_start.elapsed());
750 let _ = tokio::time::timeout(remaining, connection_tracker.notified()).await;
751 }
752 return Ok(());
753 }
754 result = listener.accept() => {
755 match result {
756 Ok(pair) => pair,
757 Err(e) => {
758 tracing::warn!(error = %e, "TCP accept error");
761 continue;
762 }
763 }
764 }
765 };
766
767 let tls_acceptor = tls_acceptor.clone();
768 let tower_service = app.clone();
769 let active = Arc::clone(&active_connections);
770 let notify = Arc::clone(&connection_tracker);
771 active.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
772
773 tokio::spawn(async move {
774 let tls_stream = match tokio::time::timeout(
776 Duration::from_secs(10),
777 tls_acceptor.accept(tcp_stream),
778 )
779 .await
780 {
781 Ok(Ok(stream)) => stream,
782 Ok(Err(e)) => {
783 tracing::debug!(
784 peer = %peer_addr,
785 error = %e,
786 "TLS handshake failed"
787 );
788 active.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
789 notify.notify_waiters();
790 return;
791 }
792 Err(_elapsed) => {
793 tracing::debug!(
794 peer = %peer_addr,
795 "TLS handshake timed out (10s)"
796 );
797 active.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
798 notify.notify_waiters();
799 return;
800 }
801 };
802
803 let io = TokioIo::new(tls_stream);
805
806 let hyper_service = hyper::service::service_fn(move |mut req: Request<Incoming>| {
809 req.extensions_mut()
810 .insert(axum::extract::ConnectInfo(peer_addr));
811 tower_service.clone().call(req)
812 });
813
814 let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
817 builder.http2().max_concurrent_streams(100);
818 let result = builder
819 .serve_connection_with_upgrades(io, hyper_service)
820 .await;
821
822 if let Err(e) = result {
823 tracing::debug!(
825 peer = %peer_addr,
826 error = %e,
827 "connection error"
828 );
829 }
830
831 active.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
832 notify.notify_waiters();
833 });
834 }
835}
836
837pub fn generate_self_signed(domain: &str) -> Result<CertificateData, TlsError> {
850 let mut params = rcgen::CertificateParams::new(vec![domain.to_owned()])
851 .map_err(|e| TlsError::Certificate(format!("failed to create cert params: {e}")))?;
852 params.distinguished_name = rcgen::DistinguishedName::new();
853 params
854 .distinguished_name
855 .push(rcgen::DnType::CommonName, domain);
856
857 let key_pair = rcgen::KeyPair::generate()
858 .map_err(|e| TlsError::Certificate(format!("failed to generate key pair: {e}")))?;
859
860 let cert = params
861 .self_signed(&key_pair)
862 .map_err(|e| TlsError::Certificate(format!("failed to generate self-signed cert: {e}")))?;
863
864 Ok(CertificateData {
865 certificate_chain_pem: cert.pem(),
866 private_key_pem: Zeroizing::new(key_pair.serialize_pem()),
867 })
868}
869
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::similar_names,
    clippy::cast_possible_wrap,
    clippy::significant_drop_tightening
)]
mod tests {
    use super::*;
    use scp_platform::testing::InMemoryStorage;

    /// Generates a self-signed certificate for `domain` and returns its DER
    /// chain together with the matching `CertifiedKey`.
    fn self_signed_certified_key(
        domain: &str,
    ) -> (
        Vec<rustls::pki_types::CertificateDer<'static>>,
        CertifiedKey,
    ) {
        let cert = generate_self_signed(domain).unwrap();
        let certs = cert.certificate_chain_der().unwrap();
        let key = cert.private_key_der().unwrap();
        let signing = rustls::crypto::ring::sign::any_supported_type(&key).unwrap();
        (certs.clone(), CertifiedKey::new(certs, signing))
    }

    /// Sends `GET /.well-known/acme-challenge/{token}` to `router`.
    async fn get_challenge(router: axum::Router, token: &str) -> axum::response::Response {
        use tower::ServiceExt;

        let request = axum::http::Request::builder()
            .uri(format!("/.well-known/acme-challenge/{token}"))
            .body(axum::body::Body::empty())
            .unwrap();

        router.oneshot(request).await.unwrap()
    }

    /// Asserts that `response` is `200 OK`, `text/plain`, with exactly
    /// `expected_body` as its body.
    async fn assert_plain_text_ok(response: axum::response::Response, expected_body: &[u8]) {
        use http_body_util::BodyExt;

        assert_eq!(response.status(), axum::http::StatusCode::OK);

        let content_type = response
            .headers()
            .get("content-type")
            .expect("should have Content-Type header")
            .to_str()
            .unwrap();
        assert_eq!(content_type, "text/plain");

        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], expected_body);
    }

    #[test]
    fn generate_self_signed_produces_valid_pem() {
        let cert = generate_self_signed("test.example.com").unwrap();
        assert!(cert.certificate_chain_pem.contains("BEGIN CERTIFICATE"));
        assert!(cert.private_key_pem.contains("BEGIN PRIVATE KEY"));
    }

    #[test]
    fn certificate_chain_der_parses_pem() {
        let cert = generate_self_signed("test.example.com").unwrap();
        let der_certs = cert.certificate_chain_der().unwrap();
        assert_eq!(
            der_certs.len(),
            1,
            "self-signed should have exactly one cert"
        );
    }

    #[test]
    fn private_key_der_parses_pem() {
        let cert = generate_self_signed("test.example.com").unwrap();
        let _key = cert.private_key_der().unwrap();
    }

    #[test]
    fn expiry_timestamp_is_in_future() {
        let cert = generate_self_signed("test.example.com").unwrap();
        let expiry = cert.expiry_timestamp().unwrap();
        let now = scp_core::time::now_secs().expect("clock unavailable in test") as i64;
        assert!(expiry > now, "self-signed cert should expire in the future");
    }

    #[test]
    fn fresh_self_signed_does_not_need_renewal() {
        let cert = generate_self_signed("test.example.com").unwrap();
        assert!(
            !cert.needs_renewal().unwrap(),
            "a freshly generated cert should not need renewal"
        );
    }

    // Checks that the basic config builds, leaves ALPN unset and is accepted
    // by `TlsAcceptor`; the protocol version itself is not inspected here.
    #[test]
    fn build_tls_server_config_enforces_tls_13() {
        let cert = generate_self_signed("test.example.com").unwrap();

        let config = build_tls_server_config(&cert).unwrap();

        assert!(
            config.alpn_protocols.is_empty(),
            "basic config should not set ALPN"
        );

        let _acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
    }

    #[test]
    fn build_reloadable_tls_config_returns_resolver() {
        let cert = generate_self_signed("test.example.com").unwrap();
        let (config, resolver) = build_reloadable_tls_config(&cert).unwrap();

        assert!(
            !config.alpn_protocols.is_empty(),
            "reloadable config should set ALPN"
        );

        let guard = resolver.inner.try_read().unwrap();
        assert!(!guard.cert.is_empty(), "resolver should have certificates");
    }

    #[tokio::test]
    async fn cert_resolver_update_swaps_certificate() {
        let (certs1, ck1) = self_signed_certified_key("one.example.com");
        let (certs2, ck2) = self_signed_certified_key("two.example.com");

        // Both chains have length one, so compare the DER contents: a length
        // comparison would pass even if `update` did nothing.
        assert_ne!(certs1, certs2, "test certificates must be distinguishable");

        let resolver = CertResolver::new(ck1);

        {
            let guard = resolver.inner.read().unwrap();
            assert_eq!(guard.cert, certs1);
        }

        resolver.update(ck2);
        {
            let guard = resolver.inner.read().unwrap();
            assert_eq!(guard.cert, certs2);
        }
    }

    #[tokio::test]
    async fn certificate_storage_roundtrip() {
        let store = ProtocolStore::new_for_testing(InMemoryStorage::new());
        let original = generate_self_signed("roundtrip.example.com").unwrap();

        store
            .store_tls_certificate(&original.certificate_chain_pem, &original.private_key_pem)
            .await
            .unwrap();

        let (cert, key) = store.load_tls_certificate().await.unwrap().unwrap();
        assert_eq!(cert, original.certificate_chain_pem);
        assert_eq!(key, original.private_key_pem);
    }

    #[tokio::test]
    async fn load_certificate_returns_none_when_empty() {
        let store = ProtocolStore::new_for_testing(InMemoryStorage::new());
        let result = store.load_tls_certificate().await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn acme_challenge_router_serves_token() {
        let challenges = Arc::new(RwLock::new(std::collections::HashMap::new()));
        {
            let mut map = challenges.write().await;
            map.insert("test-token".to_owned(), "test-key-auth".to_owned());
        }

        let router = acme_challenge_router(challenges);

        let response = get_challenge(router, "test-token").await;
        assert_plain_text_ok(response, b"test-key-auth").await;
    }

    #[tokio::test]
    async fn acme_challenge_router_returns_404_for_unknown_token() {
        let challenges = Arc::new(RwLock::new(std::collections::HashMap::new()));
        let router = acme_challenge_router(challenges);

        let response = get_challenge(router, "unknown").await;
        assert_eq!(response.status(), axum::http::StatusCode::NOT_FOUND);
    }

    #[test]
    fn acme_provider_new_sets_defaults() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("example.com", storage);

        assert_eq!(provider.domain, "example.com");
        assert!(provider.email.is_none());
        assert!(provider.directory_url.contains("letsencrypt"));
    }

    #[test]
    fn acme_provider_with_email() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("example.com", storage).with_email("admin@example.com");

        assert_eq!(provider.email.as_deref(), Some("admin@example.com"));
    }

    #[test]
    fn acme_provider_with_directory_url() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("example.com", storage)
            .with_directory_url("https://acme-staging-v02.api.letsencrypt.org/directory");

        assert!(provider.directory_url.contains("staging"));
    }

    #[test]
    fn acme_provider_with_cert_resolver() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let (_certs, ck) = self_signed_certified_key("example.com");
        let resolver = Arc::new(CertResolver::new(ck));

        let provider =
            AcmeProvider::new("example.com", storage).with_cert_resolver(Arc::clone(&resolver));

        assert!(provider.cert_resolver.is_some());
    }

    #[test]
    fn acme_provider_challenges_returns_shared_map() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("example.com", storage);

        let challenges_a = provider.challenges();
        let challenges_b = provider.challenges();

        // Both handles must point at the same underlying map.
        assert!(Arc::ptr_eq(&challenges_a, &challenges_b));
    }

    #[tokio::test]
    async fn acme_challenge_router_serves_from_shared_map() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("example.com", storage);
        let challenges = provider.challenges();

        {
            let mut map = challenges.write().await;
            map.insert("acme-token-abc".to_owned(), "key-auth-xyz".to_owned());
        }

        let router = acme_challenge_router(Arc::clone(&challenges));

        let response = get_challenge(router, "acme-token-abc").await;
        assert_plain_text_ok(response, b"key-auth-xyz").await;
    }

    #[tokio::test]
    async fn provision_without_acme_server_returns_error() {
        // Ignore the result: another test may already have installed it.
        let _ = rustls::crypto::ring::default_provider().install_default();

        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("test.example.com", storage)
            .with_directory_url("http://127.0.0.1:1/nonexistent");

        let result =
            tokio::time::timeout(std::time::Duration::from_secs(10), provider.provision()).await;

        let provision_result = result.expect("provision() should not hang");

        assert!(
            provision_result.is_err(),
            "provision() without ACME server should return TlsError"
        );
    }

    #[tokio::test]
    async fn acme_challenge_pipeline_end_to_end() {
        let storage = Arc::new(ProtocolStore::new_for_testing(InMemoryStorage::new()));
        let provider = AcmeProvider::new("test.example.com", storage);
        let challenges = provider.challenges();

        // Simulate what `provision` does when it publishes a challenge.
        {
            let mut map = challenges.write().await;
            map.insert(
                "simulated-token".to_owned(),
                "simulated-key-auth".to_owned(),
            );
        }

        let router = acme_challenge_router(provider.challenges());

        let response = get_challenge(router.clone(), "simulated-token").await;
        assert_plain_text_ok(response, b"simulated-key-auth").await;

        let response_404 = get_challenge(router, "unknown-token").await;
        assert_eq!(response_404.status(), axum::http::StatusCode::NOT_FOUND);

        // Simulate the cleanup performed once provisioning has finished.
        {
            let mut map = challenges.write().await;
            map.clear();
        }

        let router_after_clear = acme_challenge_router(provider.challenges());
        let response_cleared = get_challenge(router_after_clear, "simulated-token").await;
        assert_eq!(
            response_cleared.status(),
            axum::http::StatusCode::NOT_FOUND,
            "cleared challenge map should no longer serve token"
        );
    }
}