1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3#![deny(rustdoc::broken_intra_doc_links)]
4
5use std::collections::HashMap;
6use std::io;
7use std::path::{Path, PathBuf};
8use std::sync::{Arc, RwLock};
9use std::time::Duration;
10
11use instant_acme::{
12 Account, AuthorizationStatus, ChallengeType, Identifier, LetsEncrypt, NewAccount, NewOrder,
13 OrderStatus, RetryPolicy,
14};
15use mailrs_tls_reload::TlsState;
16use rustls::ServerConfig;
17use tokio::sync::watch;
18
19#[cfg(feature = "axum-http")]
20use std::net::SocketAddr;
21#[cfg(feature = "axum-http")]
22use tokio::net::TcpListener;
23
/// Shared table of pending ACME HTTP-01 challenges, keyed by challenge token.
///
/// Each value is the key authorization that has to be served verbatim at
/// `/.well-known/acme-challenge/{token}`. [`provision_cert`] fills the table while an
/// order is in flight and the challenge HTTP server reads from it, so both sides hold
/// clones of the same `Arc`.
pub type ChallengeTokens = Arc<RwLock<HashMap<String, String>>>;

/// Create an empty [`ChallengeTokens`] table.
pub fn new_challenge_tokens() -> ChallengeTokens {
    ChallengeTokens::default()
}
37
38pub fn build_server_config(
43 cert_pem: &str,
44 key_pem: &str,
45) -> Result<ServerConfig, Box<dyn std::error::Error>> {
46 use rustls_pki_types::pem::PemObject;
47 use rustls_pki_types::{CertificateDer, PrivateKeyDer};
48
49 let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_slice_iter(cert_pem.as_bytes())
50 .collect::<Result<Vec<_>, _>>()?;
51 let key = PrivateKeyDer::from_pem_slice(key_pem.as_bytes())?;
52
53 let config = ServerConfig::builder()
54 .with_no_client_auth()
55 .with_single_cert(certs, key)?;
56
57 Ok(config)
58}
59
60pub fn cert_days_remaining(pem_data: &[u8]) -> Result<i64, Box<dyn std::error::Error>> {
67 let (_, pem) = x509_parser::pem::parse_x509_pem(pem_data)?;
68 let (_, cert) = x509_parser::parse_x509_certificate(&pem.contents)?;
69 let not_after_ts = cert.validity().not_after.timestamp();
70 let now_ts = chrono::Utc::now().timestamp();
71 Ok((not_after_ts - now_ts) / 86400)
72}
73
74pub async fn load_or_create_account(
82 email: &str,
83 staging: bool,
84 acme_dir: &Path,
85) -> Result<Account, Box<dyn std::error::Error>> {
86 let account_path = acme_dir.join("account.json");
87
88 let url = if staging {
89 LetsEncrypt::Staging.url()
90 } else {
91 LetsEncrypt::Production.url()
92 };
93
94 if account_path.exists() {
95 let data = std::fs::read_to_string(&account_path)?;
96 let credentials: instant_acme::AccountCredentials = serde_json::from_str(&data)?;
97 let account = Account::builder()?.from_credentials(credentials).await?;
98 tracing::info!(path = %account_path.display(), "acme: loaded existing account");
99 return Ok(account);
100 }
101
102 let (account, credentials) = Account::builder()?
103 .create(
104 &NewAccount {
105 contact: &[&format!("mailto:{email}")],
106 terms_of_service_agreed: true,
107 only_return_existing: false,
108 },
109 url.to_string(),
110 None,
111 )
112 .await?;
113
114 std::fs::create_dir_all(acme_dir)?;
115 std::fs::write(&account_path, serde_json::to_string_pretty(&credentials)?)?;
116 tracing::info!(path = %account_path.display(), "acme: created new account");
117
118 Ok(account)
119}
120
121pub async fn provision_cert(
131 account: &Account,
132 domains: &[String],
133 tokens: &ChallengeTokens,
134) -> Result<(String, String), Box<dyn std::error::Error>> {
135 let identifiers: Vec<Identifier> =
136 domains.iter().map(|d| Identifier::Dns(d.clone())).collect();
137
138 let mut order = account.new_order(&NewOrder::new(&identifiers)).await?;
139
140 let state = order.state();
141 if matches!(state.status, OrderStatus::Pending) {
142 let mut authz_stream = order.authorizations();
143 while let Some(result) = authz_stream.next().await {
144 let mut authz = result?;
145 match authz.status {
146 AuthorizationStatus::Pending => {}
147 AuthorizationStatus::Valid => continue,
148 _ => return Err(format!("unexpected authz status: {:?}", authz.status).into()),
149 }
150
151 let mut challenge = authz
152 .challenge(ChallengeType::Http01)
153 .ok_or("no HTTP-01 challenge found")?;
154
155 let key_auth = challenge.key_authorization();
156 {
157 let mut map = tokens.write().unwrap();
158 map.insert(challenge.token.clone(), key_auth.as_str().to_string());
159 }
160
161 challenge.set_ready().await?;
162 }
163
164 let retries = RetryPolicy::new().timeout(Duration::from_secs(120));
166 order.poll_ready(&retries).await?;
167 }
168
169 let state = order.state();
170 if matches!(state.status, OrderStatus::Ready) {
171 let key_pem = order.finalize().await?;
172 let retries = RetryPolicy::new().timeout(Duration::from_secs(60));
173 let cert_chain = order.poll_certificate(&retries).await?;
174
175 {
177 let mut map = tokens.write().unwrap();
178 map.clear();
179 }
180
181 return Ok((cert_chain, key_pem));
182 }
183
184 Err("order already valid but no key available from this path".into())
185}
186
/// Persist a certificate chain and its private key as `cert.pem` / `key.pem` in `acme_dir`.
///
/// `acme_dir` is created if it does not exist. Both files are first written and
/// fsynced as temporary siblings (`cert.pem.tmp`, `key.pem.tmp`), then renamed into
/// place. A crash mid-write therefore never leaves a truncated PEM file, and
/// `cert.pem` and `key.pem` can disagree only in the short gap between the two
/// renames. On Unix, `key.pem` is restricted to owner read/write (`0600`).
///
/// # Errors
///
/// Returns any I/O error from creating the directory or from writing or renaming
/// the files. If staging fails, the temporary files are removed on a best-effort
/// basis.
pub fn save_cert(acme_dir: &Path, cert_pem: &str, key_pem: &str) -> io::Result<()> {
    std::fs::create_dir_all(acme_dir)?;

    let cert_tmp = acme_dir.join("cert.pem.tmp");
    let key_tmp = acme_dir.join("key.pem.tmp");

    let staged = write_staged_pem(&cert_tmp, cert_pem.as_bytes(), false)
        .and_then(|()| write_staged_pem(&key_tmp, key_pem.as_bytes(), true));
    if let Err(e) = staged {
        let _ = std::fs::remove_file(&cert_tmp);
        let _ = std::fs::remove_file(&key_tmp);
        return Err(e);
    }

    std::fs::rename(&cert_tmp, acme_dir.join("cert.pem"))?;
    std::fs::rename(&key_tmp, acme_dir.join("key.pem"))?;
    Ok(())
}

/// Write `contents` to `path` (truncating), fsync it, and, when `private` is set,
/// restrict the file to its owner on Unix before any data is written.
fn write_staged_pem(path: &Path, contents: &[u8], private: bool) -> io::Result<()> {
    use std::io::Write;

    let mut file = std::fs::File::create(path)?;
    #[cfg(unix)]
    {
        if private {
            use std::os::unix::fs::PermissionsExt;
            file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
        }
    }
    #[cfg(not(unix))]
    let _ = private;
    file.write_all(contents)?;
    file.sync_all()
}
195
/// Spawn a Tokio task serving ACME HTTP-01 challenge responses from `tokens` on `addr`.
///
/// `GET /.well-known/acme-challenge/{token}` answers `200` with the stored key
/// authorization as `text/plain` when `token` is present in `tokens`. Otherwise it
/// answers `404` with the body `not found`. The function returns immediately.
/// If binding `addr` fails, the error is logged and the task exits. The server
/// shuts down gracefully once `shutdown` observes `true`, or once the sender side
/// of the channel is dropped.
///
/// # Panics
///
/// Must be called from within a Tokio runtime, because it uses `tokio::spawn`.
#[cfg(feature = "axum-http")]
pub fn spawn_challenge_server(
    tokens: ChallengeTokens,
    addr: SocketAddr,
    mut shutdown: watch::Receiver<bool>,
) {
    tokio::spawn(async move {
        let listener = match TcpListener::bind(addr).await {
            Ok(l) => l,
            Err(e) => {
                tracing::error!(%addr, error = %e, "acme: failed to bind challenge port");
                return;
            }
        };
        tracing::info!(%addr, "acme: challenge server listening");

        let app = axum::Router::new().route(
            "/.well-known/acme-challenge/{token}",
            axum::routing::get(
                move |axum::extract::Path(token): axum::extract::Path<String>| {
                    // The handler is called many times; each call gets its own handle.
                    let tokens = tokens.clone();
                    async move {
                        // Copy the value out so the lock is released before responding;
                        // a poisoned lock still holds a usable map, so keep serving.
                        let key_auth = tokens
                            .read()
                            .unwrap_or_else(std::sync::PoisonError::into_inner)
                            .get(&token)
                            .cloned();
                        match key_auth {
                            Some(key_auth) => (
                                axum::http::StatusCode::OK,
                                [(axum::http::header::CONTENT_TYPE, "text/plain")],
                                key_auth,
                            ),
                            None => (
                                axum::http::StatusCode::NOT_FOUND,
                                [(axum::http::header::CONTENT_TYPE, "text/plain")],
                                "not found".to_string(),
                            ),
                        }
                    }
                },
            ),
        );

        let served = axum::serve(listener, app)
            .with_graceful_shutdown(async move {
                let _ = shutdown.wait_for(|v| *v).await;
            })
            .await;
        if let Err(e) = served {
            tracing::error!(%addr, error = %e, "acme: challenge server failed");
        }
    });
}
255
/// Settings for the background renewal task started by [`spawn_renewal_task`].
#[derive(Clone, Debug)]
pub struct RenewalConfig {
    /// DNS names requested on each renewed certificate.
    pub domains: Vec<String>,
    /// Directory holding `cert.pem` and `key.pem`, as written by [`save_cert`].
    pub acme_dir: PathBuf,
    /// How long the task sleeps between two expiry checks.
    pub check_interval: Duration,
    /// Renew once the certificate has this many whole days *or fewer* remaining.
    pub renew_when_days_below: i64,
}

impl Default for RenewalConfig {
    /// No domains, `./acme` as the directory, a check every 12 hours, and renewal
    /// at 30 days or fewer remaining.
    fn default() -> Self {
        Self {
            domains: Vec::new(),
            acme_dir: PathBuf::from("./acme"),
            check_interval: Duration::from_secs(12 * 60 * 60),
            renew_when_days_below: 30,
        }
    }
}
280
281pub fn spawn_renewal_task(
291 account: Account,
292 tokens: ChallengeTokens,
293 tls_state: TlsState,
294 config: RenewalConfig,
295 mut shutdown: watch::Receiver<bool>,
296) {
297 tokio::spawn(async move {
298 loop {
299 tokio::select! {
300 _ = tokio::time::sleep(config.check_interval) => {}
301 _ = shutdown.wait_for(|v| *v) => {
302 tracing::info!("acme: renewal task shutting down");
303 return;
304 }
305 }
306
307 let cert_path = config.acme_dir.join("cert.pem");
308 let cert_data = match std::fs::read(&cert_path) {
309 Ok(d) => d,
310 Err(_) => continue,
311 };
312
313 match cert_days_remaining(&cert_data) {
314 Ok(days) => {
315 tracing::info!(days, "acme: certificate expiry check");
316 if days > config.renew_when_days_below {
317 continue;
318 }
319 tracing::info!(threshold = config.renew_when_days_below, "acme: renewing");
320 }
321 Err(e) => {
322 tracing::error!(error = %e, "acme: failed to check cert expiry");
323 continue;
324 }
325 }
326
327 match provision_cert(&account, &config.domains, &tokens).await {
328 Ok((cert_pem, key_pem)) => {
329 if let Err(e) = save_cert(&config.acme_dir, &cert_pem, &key_pem) {
330 tracing::error!(error = %e, "acme: failed to save renewed cert");
331 continue;
332 }
333 match build_server_config(&cert_pem, &key_pem) {
334 Ok(server_config) => {
335 tls_state.swap(server_config);
336 tracing::info!("acme: certificate renewed and swapped");
337 }
338 Err(e) => {
339 tracing::error!(error = %e, "acme: failed to build TLS config");
340 }
341 }
342 }
343 Err(e) => {
344 tracing::error!(error = %e, "acme: renewal failed");
345 }
346 }
347 }
348 });
349}
350
351pub async fn init(
359 email: &str,
360 domains: &[String],
361 acme_dir: &Path,
362 staging: bool,
363 tokens: &ChallengeTokens,
364) -> Result<(TlsState, Account), Box<dyn std::error::Error>> {
365 let account = load_or_create_account(email, staging, acme_dir).await?;
366
367 let cert_path = acme_dir.join("cert.pem");
368 let key_path = acme_dir.join("key.pem");
369
370 let (cert_pem, key_pem) = if cert_path.exists() && key_path.exists() {
371 let cert_data = std::fs::read(&cert_path)?;
372 let days = cert_days_remaining(&cert_data).unwrap_or(0);
373 if days > 0 {
374 tracing::info!(days, "acme: existing certificate valid");
375 let cert = std::fs::read_to_string(&cert_path)?;
376 let key = std::fs::read_to_string(&key_path)?;
377 (cert, key)
378 } else {
379 tracing::info!("acme: existing certificate expired, provisioning new one");
380 let (cert, key) = provision_cert(&account, domains, tokens).await?;
381 save_cert(acme_dir, &cert, &key)?;
382 (cert, key)
383 }
384 } else {
385 tracing::info!("acme: no existing certificate, provisioning");
386 let (cert, key) = provision_cert(&account, domains, tokens).await?;
387 save_cert(acme_dir, &cert, &key)?;
388 (cert, key)
389 };
390
391 let config = build_server_config(&cert_pem, &key_pem)?;
392 let tls_state = TlsState::new(config);
393
394 Ok((tls_state, account))
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory under the system temp dir, unique per process and label.
    ///
    /// Leftovers from an earlier run are removed on creation, and the directory is
    /// removed again on drop, so a failing assertion does not leak files.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(label: &str) -> Self {
            let path =
                std::env::temp_dir().join(format!("mailrs-acme-{label}-{}", std::process::id()));
            let _ = std::fs::remove_dir_all(&path);
            Self(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn new_challenge_tokens_starts_empty() {
        let t = new_challenge_tokens();
        assert!(t.read().unwrap().is_empty());
    }

    #[test]
    fn challenge_tokens_insert_and_read() {
        let t = new_challenge_tokens();
        {
            let mut map = t.write().unwrap();
            map.insert("tok1".into(), "key_auth_1".into());
            map.insert("tok2".into(), "key_auth_2".into());
        }
        let map = t.read().unwrap();
        assert_eq!(map.get("tok1").map(String::as_str), Some("key_auth_1"));
        assert_eq!(map.get("tok2").map(String::as_str), Some("key_auth_2"));
        assert!(map.get("missing").is_none());
    }

    #[test]
    fn renewal_config_default() {
        let c = RenewalConfig::default();
        assert!(c.domains.is_empty());
        assert_eq!(c.acme_dir, PathBuf::from("./acme"));
        assert_eq!(c.check_interval, Duration::from_secs(12 * 60 * 60));
        assert_eq!(c.renew_when_days_below, 30);
    }

    #[test]
    fn renewal_config_clone() {
        let a = RenewalConfig {
            domains: vec!["example.com".into()],
            acme_dir: PathBuf::from("/var/acme"),
            check_interval: Duration::from_secs(3600),
            renew_when_days_below: 14,
        };
        let b = a.clone();
        assert_eq!(a.domains, b.domains);
        assert_eq!(a.acme_dir, b.acme_dir);
        assert_eq!(a.check_interval, b.check_interval);
        assert_eq!(a.renew_when_days_below, b.renew_when_days_below);
    }

    #[test]
    fn cert_days_remaining_rejects_garbage() {
        assert!(cert_days_remaining(b"not a PEM").is_err());
    }

    #[test]
    fn cert_days_remaining_rejects_empty() {
        assert!(cert_days_remaining(b"").is_err());
    }

    #[test]
    fn save_cert_creates_directory() {
        let dir = ScratchDir::new("test");
        save_cert(dir.path(), "cert-pem-data", "key-pem-data").expect("save_cert");
        assert!(dir.path().join("cert.pem").exists());
        assert!(dir.path().join("key.pem").exists());
        assert_eq!(
            std::fs::read_to_string(dir.path().join("cert.pem")).unwrap(),
            "cert-pem-data"
        );
    }

    #[test]
    fn save_cert_overwrites_existing() {
        let dir = ScratchDir::new("overwrite");
        save_cert(dir.path(), "old", "old-key").unwrap();
        save_cert(dir.path(), "new", "new-key").unwrap();
        assert_eq!(std::fs::read_to_string(dir.path().join("cert.pem")).unwrap(), "new");
    }

    #[test]
    fn cert_days_remaining_with_expired_pem() {
        let expired_pem = b"-----BEGIN CERTIFICATE-----
MIIBxDCCAW6gAwIBAgIUC2DnZmnxR6c6PXcyG9hqOQRJxMUwDQYJKoZIhvcNAQEL
BQAwGTEXMBUGA1UEAwwOZXhwaXJlZC50ZXN0LjAeFw0wMDAxMDEwMDAwMDBaFw0w
MDAxMDIwMDAwMDBaMBkxFzAVBgNVBAMMDmV4cGlyZWQudGVzdC4wXDANBgkqhkiG
9w0BAQEFAANLADBIAkEAuGVc7uoEgavLxc7KVxSi5q6IXkD0pAYmqr8gbZIO5p2k
KqQXNkVtoyzMOXjlV6vLOXAcgksMQQ5UqxQwlmHvOQIDAQABo1MwUTAdBgNVHQ4E
FgQUw7VxpcfRPwOOTQ6SHGyqyhI/o/owHwYDVR0jBBgwFoAUw7VxpcfRPwOOTQ6S
HGyqyhI/o/owDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAANBAGc=
-----END CERTIFICATE-----";
        // The fixture appears truncated (its outer DER length exceeds the encoded
        // bytes), so the parser may reject it. When it does parse, its validity
        // ended on 2000-01-02, so the result must be negative.
        if let Ok(days) = cert_days_remaining(expired_pem) {
            assert!(days < 0, "expected negative days, got {days}");
        }
    }

    #[test]
    fn renewal_config_can_be_constructed_with_custom_values() {
        let c = RenewalConfig {
            domains: vec!["a.com".into(), "b.com".into()],
            acme_dir: PathBuf::from("/etc/acme"),
            check_interval: Duration::from_secs(60 * 60),
            renew_when_days_below: 7,
        };
        assert_eq!(c.domains.len(), 2);
        assert_eq!(c.renew_when_days_below, 7);
        assert_eq!(c.check_interval, Duration::from_secs(3600));
    }

    #[test]
    fn challenge_tokens_default_works() {
        let t: ChallengeTokens = Default::default();
        assert!(t.read().unwrap().is_empty());
    }

    #[test]
    fn challenge_tokens_clear() {
        let t = new_challenge_tokens();
        {
            let mut map = t.write().unwrap();
            map.insert("a".into(), "1".into());
            map.insert("b".into(), "2".into());
        }
        t.write().unwrap().clear();
        assert!(t.read().unwrap().is_empty());
    }

    #[test]
    fn save_cert_preserves_exact_bytes() {
        let dir = ScratchDir::new("bytes");
        let cert = "cert-with-special-chars\n\t\r\nß";
        let key = "key-with-newlines\n\n\n";
        save_cert(dir.path(), cert, key).unwrap();
        assert_eq!(std::fs::read_to_string(dir.path().join("cert.pem")).unwrap(), cert);
        assert_eq!(std::fs::read_to_string(dir.path().join("key.pem")).unwrap(), key);
    }

    #[test]
    fn build_server_config_rejects_garbage() {
        assert!(build_server_config("not pem", "not key either").is_err());
    }
}