bws_web_server/ssl/manager.rs

1use crate::ssl::{acme::*, certificate::*};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use tokio::time::{interval, Duration};
8
/// Returns `true` when `domain` is a syntactically valid DNS host name.
///
/// The rules follow the usual host-name syntax (RFC 1035 / RFC 1123):
/// - the whole name is 1..=253 bytes long and has no leading or trailing dot;
/// - every dot-separated label is 1..=63 bytes long;
/// - labels contain only ASCII letters, ASCII digits and `-`;
/// - no label starts or ends with `-`.
///
/// Internationalized names must be supplied in their ASCII (`xn--…` punycode)
/// form, and wildcard labels such as `*` are rejected.
fn is_valid_domain(domain: &str) -> bool {
    if domain.is_empty() || domain.len() > 253 {
        return false;
    }

    // Checking per label (rather than only the ends of the whole string) also
    // rejects empty labels like `a..b` and hyphen-edged inner labels like `foo.-bar`.
    domain.split('.').all(|label| {
        !label.is_empty()
            && label.len() <= 63
            && !label.starts_with('-')
            && !label.ends_with('-')
            && label.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-')
    })
}
22
/// TLS/SSL settings for the web server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SslConfig {
    /// Master switch; when `false`, [`SslConfig::validate`] accepts the
    /// configuration without checking anything else.
    pub enabled: bool,
    /// Obtain certificates through ACME. Requires `acme` to be set.
    pub auto_cert: bool,
    /// Directory holding the certificate store (`certificates.toml`) and the
    /// certificate/key files written for ACME-issued certificates.
    pub cert_dir: String,
    /// ACME settings; required when `auto_cert` is `true`.
    pub acme: Option<AcmeConfig>,
    /// Operator-supplied certificates keyed by domain name. `SslManager` only
    /// consults these when `auto_cert` is `false`.
    pub manual_certs: HashMap<String, ManualCertConfig>,
    /// Period, in hours, of the renewal monitor's check loop.
    pub renewal_check_interval_hours: u64,
    /// Threshold, in days before expiry, passed to the certificate store when
    /// selecting certificates that need renewal.
    pub renewal_days_before_expiry: i64,
}
33
/// A certificate/key pair supplied by the operator rather than issued via ACME.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ManualCertConfig {
    /// Path to the certificate file; it must exist when the certificate is loaded.
    pub cert_file: String,
    /// Path to the private key file; it must exist when the certificate is loaded.
    pub key_file: String,
    /// Passed to `Certificate::from_files` as the certificate's auto-renew flag.
    pub auto_renew: bool,
}
40
41impl Default for SslConfig {
42    fn default() -> Self {
43        Self {
44            enabled: false,
45            auto_cert: false,
46            cert_dir: "/etc/bws/certs".to_string(),
47            acme: None,
48            manual_certs: HashMap::new(),
49            renewal_check_interval_hours: 24, // Check daily
50            renewal_days_before_expiry: 30,   // Renew 30 days before expiry
51        }
52    }
53}
54
/// Owns the server's TLS certificate state.
///
/// Holds:
/// - the configuration it was built from;
/// - the persistent certificate store;
/// - an ACME client, present only in `auto_cert` mode;
/// - the per-domain rustls server configurations built from stored certificates.
#[derive(Debug)]
pub struct SslManager {
    /// Configuration passed to [`SslManager::new`].
    config: SslConfig,
    /// Certificate store backed by `<cert_dir>/certificates.toml`.
    certificate_store: Arc<RwLock<CertificateStore>>,
    /// ACME client; `Some` only when `config.auto_cert` is `true`.
    acme_client: Option<Arc<RwLock<AcmeClient>>>,
    /// rustls server configuration for each domain, keyed by domain name.
    tls_configs: Arc<RwLock<HashMap<String, rustls::ServerConfig>>>,
}
62
63impl SslManager {
64    pub async fn new(config: SslConfig) -> Result<Self, Box<dyn std::error::Error>> {
65        // Ensure certificate directory exists
66        ensure_certificate_directory(&config.cert_dir).await?;
67
68        // Initialize certificate store
69        let store_path = PathBuf::from(&config.cert_dir).join("certificates.toml");
70        let mut certificate_store = CertificateStore::new(store_path);
71        certificate_store.load().await?;
72
73        // Initialize ACME client if auto_cert is enabled
74        let acme_client = if config.auto_cert {
75            if let Some(acme_config) = &config.acme {
76                let client = AcmeClient::new(acme_config.clone());
77                Some(Arc::new(RwLock::new(client)))
78            } else {
79                return Err("ACME configuration required when auto_cert is enabled".into());
80            }
81        } else {
82            None
83        };
84
85        let manager = Self {
86            config,
87            certificate_store: Arc::new(RwLock::new(certificate_store)),
88            acme_client,
89            tls_configs: Arc::new(RwLock::new(HashMap::new())),
90        };
91
92        // Load existing certificates
93        manager.load_certificates().await?;
94
95        Ok(manager)
96    }
97
98    pub async fn get_tls_config(&self, domain: &str) -> Option<rustls::ServerConfig> {
99        let configs = self.tls_configs.read().await;
100        configs.get(domain).cloned()
101    }
102
103    pub async fn ensure_certificate(
104        &self,
105        domain: &str,
106    ) -> Result<bool, Box<dyn std::error::Error>> {
107        // Check if we already have a valid certificate
108        {
109            let store = self.certificate_store.read().await;
110            if let Some(cert) = store.get_certificate(domain) {
111                if !cert.is_expired() && cert.validate_certificate_files().await.unwrap_or(false) {
112                    log::info!("Valid certificate already exists for {}", domain);
113                    return Ok(true);
114                }
115            }
116        }
117
118        // Try to obtain certificate
119        if self.config.auto_cert {
120            self.obtain_certificate_via_acme(domain).await
121        } else {
122            // Check manual certificate configuration
123            if let Some(manual_config) = self.config.manual_certs.get(domain) {
124                self.load_manual_certificate(domain, manual_config).await
125            } else {
126                log::warn!("No certificate configuration found for domain: {}", domain);
127                Ok(false)
128            }
129        }
130    }
131
132    async fn obtain_certificate_via_acme(
133        &self,
134        domain: &str,
135    ) -> Result<bool, Box<dyn std::error::Error>> {
136        if !is_valid_domain(domain) {
137            return Err(format!("Invalid domain name: {}", domain).into());
138        }
139
140        log::info!("Obtaining certificate via ACME for domain: {}", domain);
141
142        if let Some(acme_client) = &self.acme_client {
143            let mut client = acme_client.write().await;
144            let (cert_pem, key_pem): (String, String) = client
145                .obtain_certificate(&[domain.to_string()])
146                .await
147                .map_err(|e| {
148                    Box::new(std::io::Error::other(e.to_string())) as Box<dyn std::error::Error>
149                })?;
150
151            // Create certificate paths
152            let cert_path = get_certificate_path(domain, &self.config.cert_dir);
153            let key_path = get_key_path(domain, &self.config.cert_dir);
154
155            // Create certificate object
156            let certificate = Certificate::from_files(
157                domain.to_string(),
158                cert_path.clone(),
159                key_path.clone(),
160                true, // auto_renew enabled for ACME certificates
161            )
162            .await?;
163
164            // Save certificate files
165            certificate.save_certificate(&cert_pem, &key_pem).await?;
166
167            // Add to store
168            {
169                let mut store = self.certificate_store.write().await;
170                store.add_certificate(certificate);
171                store.save().await?;
172            }
173
174            // Update TLS config
175            self.update_tls_config(domain).await?;
176
177            log::info!(
178                "Successfully obtained and configured certificate for {}",
179                domain
180            );
181            Ok(true)
182        } else {
183            Err("ACME client not initialized".into())
184        }
185    }
186
187    async fn load_manual_certificate(
188        &self,
189        domain: &str,
190        manual_config: &ManualCertConfig,
191    ) -> Result<bool, Box<dyn std::error::Error>> {
192        log::info!("Loading manual certificate for domain: {}", domain);
193
194        let cert_path = PathBuf::from(&manual_config.cert_file);
195        let key_path = PathBuf::from(&manual_config.key_file);
196
197        // Validate certificate files exist and are readable
198        if !cert_path.exists() {
199            return Err(format!("Certificate file not found: {}", cert_path.display()).into());
200        }
201        if !key_path.exists() {
202            return Err(format!("Key file not found: {}", key_path.display()).into());
203        }
204
205        // Create certificate object
206        let certificate = Certificate::from_files(
207            domain.to_string(),
208            cert_path,
209            key_path,
210            manual_config.auto_renew,
211        )
212        .await?;
213
214        // Validate certificate files
215        if !certificate.validate_certificate_files().await? {
216            return Err(format!("Invalid certificate files for domain: {}", domain).into());
217        }
218
219        // Add to store
220        {
221            let mut store = self.certificate_store.write().await;
222            store.add_certificate(certificate);
223            store.save().await?;
224        }
225
226        // Update TLS config
227        self.update_tls_config(domain).await?;
228
229        log::info!("Successfully loaded manual certificate for {}", domain);
230        Ok(true)
231    }
232
233    async fn update_tls_config(&self, domain: &str) -> Result<(), Box<dyn std::error::Error>> {
234        let store = self.certificate_store.read().await;
235        if let Some(certificate) = store.get_certificate(domain) {
236            let tls_config = certificate.get_rustls_config().await?;
237            let mut configs = self.tls_configs.write().await;
238            configs.insert(domain.to_string(), tls_config);
239            log::info!("Updated TLS configuration for {}", domain);
240        }
241        Ok(())
242    }
243
244    async fn load_certificates(&self) -> Result<(), Box<dyn std::error::Error>> {
245        let certificates = {
246            let store = self.certificate_store.read().await;
247            store.list_certificates().to_vec()
248        };
249
250        for certificate in certificates {
251            // Validate certificate files
252            if certificate
253                .validate_certificate_files()
254                .await
255                .unwrap_or(false)
256            {
257                // Update TLS config for valid certificates
258                self.update_tls_config(&certificate.domain).await?;
259            } else {
260                log::warn!(
261                    "Certificate files invalid for domain: {}, will attempt renewal",
262                    certificate.domain
263                );
264            }
265        }
266
267        let store = self.certificate_store.read().await;
268        let cert_count = store.list_certificates().len();
269        drop(store);
270
271        log::info!("Loaded certificates for {} domains", cert_count);
272        Ok(())
273    }
274
275    pub async fn start_renewal_monitor(self: Arc<Self>) {
276        let renewal_interval = Duration::from_secs(self.config.renewal_check_interval_hours * 3600);
277        let mut interval_timer = interval(renewal_interval);
278
279        log::info!(
280            "Starting certificate renewal monitor (check every {} hours)",
281            self.config.renewal_check_interval_hours
282        );
283
284        loop {
285            interval_timer.tick().await;
286            if let Err(e) = self.check_and_renew_certificates().await {
287                log::error!("Error during certificate renewal check: {}", e);
288            }
289        }
290    }
291
292    async fn check_and_renew_certificates(&self) -> Result<(), Box<dyn std::error::Error>> {
293        log::info!("Checking certificates for renewal");
294
295        let certificates_to_renew = {
296            let store = self.certificate_store.read().await;
297            store
298                .get_certificates_needing_renewal(self.config.renewal_days_before_expiry)
299                .into_iter()
300                .map(|cert| cert.domain.clone())
301                .collect::<Vec<_>>()
302        };
303
304        if certificates_to_renew.is_empty() {
305            log::info!("No certificates need renewal");
306            return Ok(());
307        }
308
309        log::info!(
310            "Found {} certificates that need renewal",
311            certificates_to_renew.len()
312        );
313
314        for domain in certificates_to_renew {
315            log::info!("Renewing certificate for domain: {}", domain);
316
317            // Update renewal check timestamp
318            {
319                let mut store = self.certificate_store.write().await;
320                store.update_renewal_check(&domain);
321                store.save().await?;
322            }
323
324            match self.renew_certificate(&domain).await {
325                Ok(()) => {
326                    log::info!("Successfully renewed certificate for {}", domain);
327                }
328                Err(e) => {
329                    log::error!("Failed to renew certificate for {}: {}", domain, e);
330                    // Continue with other certificates even if one fails
331                }
332            }
333        }
334
335        Ok(())
336    }
337
338    async fn renew_certificate(&self, domain: &str) -> Result<(), Box<dyn std::error::Error>> {
339        if self.config.auto_cert {
340            // For ACME certificates, obtain a new certificate
341            self.obtain_certificate_via_acme(domain).await?;
342        } else {
343            // For manual certificates, reload from files (in case they were updated)
344            if let Some(manual_config) = self.config.manual_certs.get(domain) {
345                self.load_manual_certificate(domain, manual_config).await?;
346            } else {
347                return Err(format!("No renewal method available for domain: {}", domain).into());
348            }
349        }
350
351        Ok(())
352    }
353
354    pub async fn remove_certificate(
355        &self,
356        domain: &str,
357    ) -> Result<bool, Box<dyn std::error::Error>> {
358        let removed = {
359            let mut store = self.certificate_store.write().await;
360            let removed = store.remove_certificate(domain);
361            if removed {
362                store.save().await?;
363            }
364            removed
365        };
366
367        if removed {
368            // Remove from TLS configs
369            let mut configs = self.tls_configs.write().await;
370            configs.remove(domain);
371            log::info!("Removed certificate for domain: {}", domain);
372        }
373
374        Ok(removed)
375    }
376
377    pub async fn list_certificates(&self) -> Vec<Certificate> {
378        let store = self.certificate_store.read().await;
379        store.list_certificates().to_vec()
380    }
381
382    pub async fn get_certificate_info(&self, domain: &str) -> Option<Certificate> {
383        let store = self.certificate_store.read().await;
384        store.get_certificate(domain).cloned()
385    }
386
387    pub fn is_ssl_enabled(&self) -> bool {
388        self.config.enabled
389    }
390
391    pub fn handles_acme_challenge(&self, path: &str) -> bool {
392        path.starts_with("/.well-known/acme-challenge/")
393    }
394
395    pub async fn get_acme_challenge_response(&self, token: &str) -> Option<String> {
396        if let Some(acme_client) = &self.acme_client {
397            let client = acme_client.read().await;
398            client.get_challenge_content(token)
399        } else {
400            None
401        }
402    }
403}
404
405// Configuration validation
406impl SslConfig {
407    pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
408        if !self.enabled {
409            return Ok(());
410        }
411
412        if self.auto_cert {
413            if self.acme.is_none() {
414                return Err("ACME configuration required when auto_cert is enabled".into());
415            }
416
417            if let Some(acme_config) = &self.acme {
418                if acme_config.contact_email.is_empty() {
419                    return Err("ACME email is required".into());
420                }
421                if !acme_config.terms_agreed {
422                    return Err("ACME terms of service must be agreed to".into());
423                }
424            }
425        }
426
427        for (domain, manual_config) in &self.manual_certs {
428            if !is_valid_domain(domain) {
429                return Err(format!("Invalid domain name in manual_certs: {}", domain).into());
430            }
431
432            if manual_config.cert_file.is_empty() {
433                return Err(
434                    format!("Certificate file path required for domain: {}", domain).into(),
435                );
436            }
437
438            if manual_config.key_file.is_empty() {
439                return Err(format!("Key file path required for domain: {}", domain).into());
440            }
441        }
442
443        if self.renewal_check_interval_hours == 0 {
444            return Err("Renewal check interval must be greater than 0".into());
445        }
446
447        if self.renewal_days_before_expiry < 1 {
448            return Err("Renewal days before expiry must be at least 1".into());
449        }
450
451        Ok(())
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manual certificate entry with automatic renewal turned off.
    fn manual_cert(cert_file: &str, key_file: &str) -> ManualCertConfig {
        ManualCertConfig {
            cert_file: cert_file.to_string(),
            key_file: key_file.to_string(),
            auto_renew: false,
        }
    }

    #[test]
    fn test_ssl_config_validation() {
        let mut config = SslConfig::default();
        // The default configuration has SSL disabled, which always validates.
        assert!(config.validate().is_ok());

        // Turning on auto_cert without any ACME section must fail.
        config.enabled = true;
        config.auto_cert = true;
        assert!(config.validate().is_err());

        // A default ACME section still lacks a contact email.
        config.acme = Some(AcmeConfig::default());
        assert!(config.validate().is_err());

        // An email alone is not enough: the terms must be accepted too.
        config.acme.as_mut().expect("ACME section was set above").contact_email =
            "test@example.com".to_string();
        assert!(config.validate().is_err());

        // With both email and accepted terms the configuration is valid.
        config.acme.as_mut().expect("ACME section was set above").terms_agreed = true;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_manual_cert_validation() {
        let mut config = SslConfig {
            enabled: true,
            auto_cert: false,
            ..Default::default()
        };

        // An empty domain name is rejected.
        config
            .manual_certs
            .insert(String::new(), manual_cert("cert.pem", "key.pem"));
        assert!(config.validate().is_err());

        // A valid domain with an empty certificate path is rejected.
        config.manual_certs.clear();
        config
            .manual_certs
            .insert("example.com".to_string(), manual_cert("", "key.pem"));
        assert!(config.validate().is_err());

        // Filling in the certificate path but blanking the key path is rejected.
        {
            let entry = config
                .manual_certs
                .get_mut("example.com")
                .expect("entry inserted above");
            entry.cert_file = "cert.pem".to_string();
            entry.key_file = String::new();
        }
        assert!(config.validate().is_err());

        // Once both paths are present the configuration is valid.
        config
            .manual_certs
            .get_mut("example.com")
            .expect("entry inserted above")
            .key_file = "key.pem".to_string();
        assert!(config.validate().is_ok());
    }
}