1use crate::ssl::{acme::*, certificate::*};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use tokio::time::{interval, Duration};
8
/// Returns `true` if `domain` is a syntactically valid DNS hostname.
///
/// Applies the RFC 1035 / RFC 1123 hostname rules:
/// - the whole name is 1..=253 bytes long and has no leading or trailing dot;
/// - every dot-separated label is 1..=63 bytes long;
/// - labels contain only ASCII letters, ASCII digits and `-`, and never start
///   or end with `-`.
///
/// Only ASCII is accepted: internationalized names must be given in their
/// punycode (`xn--…`) form. Wildcards such as `*.example.com` are rejected.
fn is_valid_domain(domain: &str) -> bool {
    const MAX_DOMAIN_LEN: usize = 253;
    const MAX_LABEL_LEN: usize = 63;

    if domain.is_empty() || domain.len() > MAX_DOMAIN_LEN {
        return false;
    }

    // `split('.')` yields an empty label for a leading dot, a trailing dot or
    // `..`, so the emptiness check below rejects all three.
    domain.split('.').all(|label| {
        !label.is_empty()
            && label.len() <= MAX_LABEL_LEN
            && !label.starts_with('-')
            && !label.ends_with('-')
            // ASCII-only: `char::is_alphanumeric` would also admit Unicode
            // letters/digits, which are not valid in hostnames.
            && label
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'-')
    })
}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct SslConfig {
25 pub enabled: bool,
26 pub auto_cert: bool,
27 pub cert_dir: String,
28 pub acme: Option<AcmeConfig>,
29 pub manual_certs: HashMap<String, ManualCertConfig>,
30 pub renewal_check_interval_hours: u64,
31 pub renewal_days_before_expiry: i64,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ManualCertConfig {
36 pub cert_file: String,
37 pub key_file: String,
38 pub auto_renew: bool,
39}
40
41impl Default for SslConfig {
42 fn default() -> Self {
43 Self {
44 enabled: false,
45 auto_cert: false,
46 cert_dir: "/etc/bws/certs".to_string(),
47 acme: None,
48 manual_certs: HashMap::new(),
49 renewal_check_interval_hours: 24, renewal_days_before_expiry: 30, }
52 }
53}
54
55#[derive(Debug)]
56pub struct SslManager {
57 config: SslConfig,
58 certificate_store: Arc<RwLock<CertificateStore>>,
59 acme_client: Option<Arc<RwLock<AcmeClient>>>,
60 tls_configs: Arc<RwLock<HashMap<String, rustls::ServerConfig>>>,
61}
62
63impl SslManager {
64 pub async fn new(config: SslConfig) -> Result<Self, Box<dyn std::error::Error>> {
65 ensure_certificate_directory(&config.cert_dir).await?;
67
68 let store_path = PathBuf::from(&config.cert_dir).join("certificates.toml");
70 let mut certificate_store = CertificateStore::new(store_path);
71 certificate_store.load().await?;
72
73 let acme_client = if config.auto_cert {
75 if let Some(acme_config) = &config.acme {
76 let client = AcmeClient::new(acme_config.clone());
77 Some(Arc::new(RwLock::new(client)))
78 } else {
79 return Err("ACME configuration required when auto_cert is enabled".into());
80 }
81 } else {
82 None
83 };
84
85 let manager = Self {
86 config,
87 certificate_store: Arc::new(RwLock::new(certificate_store)),
88 acme_client,
89 tls_configs: Arc::new(RwLock::new(HashMap::new())),
90 };
91
92 manager.load_certificates().await?;
94
95 Ok(manager)
96 }
97
98 pub async fn get_tls_config(&self, domain: &str) -> Option<rustls::ServerConfig> {
99 let configs = self.tls_configs.read().await;
100 configs.get(domain).cloned()
101 }
102
103 pub async fn ensure_certificate(
104 &self,
105 domain: &str,
106 ) -> Result<bool, Box<dyn std::error::Error>> {
107 {
109 let store = self.certificate_store.read().await;
110 if let Some(cert) = store.get_certificate(domain) {
111 if !cert.is_expired() && cert.validate_certificate_files().await.unwrap_or(false) {
112 log::info!("Valid certificate already exists for {}", domain);
113 return Ok(true);
114 }
115 }
116 }
117
118 if self.config.auto_cert {
120 self.obtain_certificate_via_acme(domain).await
121 } else {
122 if let Some(manual_config) = self.config.manual_certs.get(domain) {
124 self.load_manual_certificate(domain, manual_config).await
125 } else {
126 log::warn!("No certificate configuration found for domain: {}", domain);
127 Ok(false)
128 }
129 }
130 }
131
132 async fn obtain_certificate_via_acme(
133 &self,
134 domain: &str,
135 ) -> Result<bool, Box<dyn std::error::Error>> {
136 if !is_valid_domain(domain) {
137 return Err(format!("Invalid domain name: {}", domain).into());
138 }
139
140 log::info!("Obtaining certificate via ACME for domain: {}", domain);
141
142 if let Some(acme_client) = &self.acme_client {
143 let mut client = acme_client.write().await;
144 let (cert_pem, key_pem): (String, String) = client
145 .obtain_certificate(&[domain.to_string()])
146 .await
147 .map_err(|e| {
148 Box::new(std::io::Error::other(e.to_string())) as Box<dyn std::error::Error>
149 })?;
150
151 let cert_path = get_certificate_path(domain, &self.config.cert_dir);
153 let key_path = get_key_path(domain, &self.config.cert_dir);
154
155 let certificate = Certificate::from_files(
157 domain.to_string(),
158 cert_path.clone(),
159 key_path.clone(),
160 true, )
162 .await?;
163
164 certificate.save_certificate(&cert_pem, &key_pem).await?;
166
167 {
169 let mut store = self.certificate_store.write().await;
170 store.add_certificate(certificate);
171 store.save().await?;
172 }
173
174 self.update_tls_config(domain).await?;
176
177 log::info!(
178 "Successfully obtained and configured certificate for {}",
179 domain
180 );
181 Ok(true)
182 } else {
183 Err("ACME client not initialized".into())
184 }
185 }
186
187 async fn load_manual_certificate(
188 &self,
189 domain: &str,
190 manual_config: &ManualCertConfig,
191 ) -> Result<bool, Box<dyn std::error::Error>> {
192 log::info!("Loading manual certificate for domain: {}", domain);
193
194 let cert_path = PathBuf::from(&manual_config.cert_file);
195 let key_path = PathBuf::from(&manual_config.key_file);
196
197 if !cert_path.exists() {
199 return Err(format!("Certificate file not found: {}", cert_path.display()).into());
200 }
201 if !key_path.exists() {
202 return Err(format!("Key file not found: {}", key_path.display()).into());
203 }
204
205 let certificate = Certificate::from_files(
207 domain.to_string(),
208 cert_path,
209 key_path,
210 manual_config.auto_renew,
211 )
212 .await?;
213
214 if !certificate.validate_certificate_files().await? {
216 return Err(format!("Invalid certificate files for domain: {}", domain).into());
217 }
218
219 {
221 let mut store = self.certificate_store.write().await;
222 store.add_certificate(certificate);
223 store.save().await?;
224 }
225
226 self.update_tls_config(domain).await?;
228
229 log::info!("Successfully loaded manual certificate for {}", domain);
230 Ok(true)
231 }
232
233 async fn update_tls_config(&self, domain: &str) -> Result<(), Box<dyn std::error::Error>> {
234 let store = self.certificate_store.read().await;
235 if let Some(certificate) = store.get_certificate(domain) {
236 let tls_config = certificate.get_rustls_config().await?;
237 let mut configs = self.tls_configs.write().await;
238 configs.insert(domain.to_string(), tls_config);
239 log::info!("Updated TLS configuration for {}", domain);
240 }
241 Ok(())
242 }
243
244 async fn load_certificates(&self) -> Result<(), Box<dyn std::error::Error>> {
245 let certificates = {
246 let store = self.certificate_store.read().await;
247 store.list_certificates().to_vec()
248 };
249
250 for certificate in certificates {
251 if certificate
253 .validate_certificate_files()
254 .await
255 .unwrap_or(false)
256 {
257 self.update_tls_config(&certificate.domain).await?;
259 } else {
260 log::warn!(
261 "Certificate files invalid for domain: {}, will attempt renewal",
262 certificate.domain
263 );
264 }
265 }
266
267 let store = self.certificate_store.read().await;
268 let cert_count = store.list_certificates().len();
269 drop(store);
270
271 log::info!("Loaded certificates for {} domains", cert_count);
272 Ok(())
273 }
274
275 pub async fn start_renewal_monitor(self: Arc<Self>) {
276 let renewal_interval = Duration::from_secs(self.config.renewal_check_interval_hours * 3600);
277 let mut interval_timer = interval(renewal_interval);
278
279 log::info!(
280 "Starting certificate renewal monitor (check every {} hours)",
281 self.config.renewal_check_interval_hours
282 );
283
284 loop {
285 interval_timer.tick().await;
286 if let Err(e) = self.check_and_renew_certificates().await {
287 log::error!("Error during certificate renewal check: {}", e);
288 }
289 }
290 }
291
292 async fn check_and_renew_certificates(&self) -> Result<(), Box<dyn std::error::Error>> {
293 log::info!("Checking certificates for renewal");
294
295 let certificates_to_renew = {
296 let store = self.certificate_store.read().await;
297 store
298 .get_certificates_needing_renewal(self.config.renewal_days_before_expiry)
299 .into_iter()
300 .map(|cert| cert.domain.clone())
301 .collect::<Vec<_>>()
302 };
303
304 if certificates_to_renew.is_empty() {
305 log::info!("No certificates need renewal");
306 return Ok(());
307 }
308
309 log::info!(
310 "Found {} certificates that need renewal",
311 certificates_to_renew.len()
312 );
313
314 for domain in certificates_to_renew {
315 log::info!("Renewing certificate for domain: {}", domain);
316
317 {
319 let mut store = self.certificate_store.write().await;
320 store.update_renewal_check(&domain);
321 store.save().await?;
322 }
323
324 match self.renew_certificate(&domain).await {
325 Ok(()) => {
326 log::info!("Successfully renewed certificate for {}", domain);
327 }
328 Err(e) => {
329 log::error!("Failed to renew certificate for {}: {}", domain, e);
330 }
332 }
333 }
334
335 Ok(())
336 }
337
338 async fn renew_certificate(&self, domain: &str) -> Result<(), Box<dyn std::error::Error>> {
339 if self.config.auto_cert {
340 self.obtain_certificate_via_acme(domain).await?;
342 } else {
343 if let Some(manual_config) = self.config.manual_certs.get(domain) {
345 self.load_manual_certificate(domain, manual_config).await?;
346 } else {
347 return Err(format!("No renewal method available for domain: {}", domain).into());
348 }
349 }
350
351 Ok(())
352 }
353
354 pub async fn remove_certificate(
355 &self,
356 domain: &str,
357 ) -> Result<bool, Box<dyn std::error::Error>> {
358 let removed = {
359 let mut store = self.certificate_store.write().await;
360 let removed = store.remove_certificate(domain);
361 if removed {
362 store.save().await?;
363 }
364 removed
365 };
366
367 if removed {
368 let mut configs = self.tls_configs.write().await;
370 configs.remove(domain);
371 log::info!("Removed certificate for domain: {}", domain);
372 }
373
374 Ok(removed)
375 }
376
377 pub async fn list_certificates(&self) -> Vec<Certificate> {
378 let store = self.certificate_store.read().await;
379 store.list_certificates().to_vec()
380 }
381
382 pub async fn get_certificate_info(&self, domain: &str) -> Option<Certificate> {
383 let store = self.certificate_store.read().await;
384 store.get_certificate(domain).cloned()
385 }
386
387 pub fn is_ssl_enabled(&self) -> bool {
388 self.config.enabled
389 }
390
391 pub fn handles_acme_challenge(&self, path: &str) -> bool {
392 path.starts_with("/.well-known/acme-challenge/")
393 }
394
395 pub async fn get_acme_challenge_response(&self, token: &str) -> Option<String> {
396 if let Some(acme_client) = &self.acme_client {
397 let client = acme_client.read().await;
398 client.get_challenge_content(token)
399 } else {
400 None
401 }
402 }
403}
404
405impl SslConfig {
407 pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
408 if !self.enabled {
409 return Ok(());
410 }
411
412 if self.auto_cert {
413 if self.acme.is_none() {
414 return Err("ACME configuration required when auto_cert is enabled".into());
415 }
416
417 if let Some(acme_config) = &self.acme {
418 if acme_config.contact_email.is_empty() {
419 return Err("ACME email is required".into());
420 }
421 if !acme_config.terms_agreed {
422 return Err("ACME terms of service must be agreed to".into());
423 }
424 }
425 }
426
427 for (domain, manual_config) in &self.manual_certs {
428 if !is_valid_domain(domain) {
429 return Err(format!("Invalid domain name in manual_certs: {}", domain).into());
430 }
431
432 if manual_config.cert_file.is_empty() {
433 return Err(
434 format!("Certificate file path required for domain: {}", domain).into(),
435 );
436 }
437
438 if manual_config.key_file.is_empty() {
439 return Err(format!("Key file path required for domain: {}", domain).into());
440 }
441 }
442
443 if self.renewal_check_interval_hours == 0 {
444 return Err("Renewal check interval must be greater than 0".into());
445 }
446
447 if self.renewal_days_before_expiry < 1 {
448 return Err("Renewal days before expiry must be at least 1".into());
449 }
450
451 Ok(())
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manual certificate entry with `auto_renew` disabled.
    fn manual_cert(cert_file: &str, key_file: &str) -> ManualCertConfig {
        ManualCertConfig {
            cert_file: cert_file.to_string(),
            key_file: key_file.to_string(),
            auto_renew: false,
        }
    }

    /// Walks an auto-cert configuration from invalid to valid, one field at a time.
    #[test]
    fn test_ssl_config_validation() {
        let mut config = SslConfig::default();
        // Disabled SSL is accepted as-is.
        assert!(config.validate().is_ok());

        config.enabled = true;
        config.auto_cert = true;
        // auto_cert without an ACME section.
        assert!(config.validate().is_err());

        config.acme = Some(AcmeConfig::default());
        // ACME section present but still incomplete.
        assert!(config.validate().is_err());

        config
            .acme
            .as_mut()
            .expect("acme was set above")
            .contact_email = "test@example.com".to_string();
        // Email present, terms not yet agreed.
        assert!(config.validate().is_err());

        config.acme.as_mut().expect("acme was set above").terms_agreed = true;
        assert!(config.validate().is_ok());
    }

    /// Checks domain-name, certificate-path and key-path validation for manual certs.
    #[test]
    fn test_manual_cert_validation() {
        let mut config = SslConfig {
            enabled: true,
            auto_cert: false,
            ..Default::default()
        };

        // Empty domain name.
        config
            .manual_certs
            .insert(String::new(), manual_cert("cert.pem", "key.pem"));
        assert!(config.validate().is_err());

        // Missing certificate path.
        config.manual_certs.clear();
        config
            .manual_certs
            .insert("example.com".to_string(), manual_cert("", "key.pem"));
        assert!(config.validate().is_err());

        // Certificate path fixed, key path now missing.
        let entry = config.manual_certs.get_mut("example.com").unwrap();
        entry.cert_file = "cert.pem".to_string();
        entry.key_file = String::new();
        assert!(config.validate().is_err());

        // Both paths present.
        config
            .manual_certs
            .insert("example.com".to_string(), manual_cert("cert.pem", "key.pem"));
        assert!(config.validate().is_ok());
    }
}