// bws_web_server/ssl/certificate.rs

use chrono::{DateTime, Utc};
use rustls_pemfile::{certs, private_key};
use serde::{Deserialize, Serialize};
use std::io::BufReader;
use std::path::{Path, PathBuf};
use tokio::fs;
use x509_parser::prelude::*;
8
/// A TLS certificate/key pair on disk plus the metadata needed to decide when it must be renewed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Certificate {
    /// Primary domain; the key used by `CertificateStore` to add, replace, remove and update entries.
    pub domain: String,
    /// Location of the PEM certificate (chain) file.
    pub cert_path: PathBuf,
    /// Location of the PEM private-key file.
    pub key_path: PathBuf,
    /// Start of the certificate's validity period (notBefore of the first certificate in the PEM data).
    pub issued_at: DateTime<Utc>,
    /// End of the certificate's validity period (notAfter of the first certificate in the PEM data).
    pub expires_at: DateTime<Utc>,
    /// Issuer common name as parsed from the certificate ("Unknown" if it has none).
    pub issuer: String,
    /// Additional names, besides `domain`, that `covers_domain` accepts.
    pub san_domains: Vec<String>,
    /// When `false`, `needs_renewal` always reports `false`.
    pub auto_renew: bool,
    /// When a renewal check last ran for this certificate; `None` until one is recorded.
    pub last_renewal_check: Option<DateTime<Utc>>,
}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CertificateStore {
24    pub certificates: Vec<Certificate>,
25    pub storage_path: PathBuf,
26}
27
28impl Certificate {
29    pub async fn from_files(
30        domain: String,
31        cert_path: PathBuf,
32        key_path: PathBuf,
33        auto_renew: bool,
34    ) -> Result<Self, Box<dyn std::error::Error>> {
35        // Read certificate file
36        let cert_data = fs::read(&cert_path).await?;
37        let cert_info = Self::parse_certificate(&cert_data)?;
38
39        Ok(Certificate {
40            domain,
41            cert_path,
42            key_path,
43            issued_at: cert_info.issued_at,
44            expires_at: cert_info.expires_at,
45            issuer: cert_info.issuer,
46            san_domains: cert_info.san_domains,
47            auto_renew,
48            last_renewal_check: None,
49        })
50    }
51
52    pub fn from_pem_data(
53        domain: String,
54        cert_path: PathBuf,
55        key_path: PathBuf,
56        cert_pem: &str,
57        auto_renew: bool,
58    ) -> Result<Self, Box<dyn std::error::Error>> {
59        // Parse certificate from PEM data
60        let cert_info = Self::parse_certificate(cert_pem.as_bytes())?;
61
62        Ok(Certificate {
63            domain,
64            cert_path,
65            key_path,
66            issued_at: cert_info.issued_at,
67            expires_at: cert_info.expires_at,
68            issuer: cert_info.issuer,
69            san_domains: cert_info.san_domains,
70            auto_renew,
71            last_renewal_check: None,
72        })
73    }
74
75    pub async fn save_certificate(
76        &self,
77        cert_pem: &str,
78        key_pem: &str,
79    ) -> Result<(), Box<dyn std::error::Error>> {
80        // Ensure parent directories exist
81        if let Some(parent) = self.cert_path.parent() {
82            fs::create_dir_all(parent).await?;
83        }
84        if let Some(parent) = self.key_path.parent() {
85            fs::create_dir_all(parent).await?;
86        }
87
88        // Write certificate and key files
89        fs::write(&self.cert_path, cert_pem).await?;
90        fs::write(&self.key_path, key_pem).await?;
91
92        // Set appropriate permissions (600 for key file)
93        #[cfg(unix)]
94        {
95            use std::os::unix::fs::PermissionsExt;
96            let key_perms = std::fs::Permissions::from_mode(0o600);
97            std::fs::set_permissions(&self.key_path, key_perms)?;
98
99            let cert_perms = std::fs::Permissions::from_mode(0o644);
100            std::fs::set_permissions(&self.cert_path, cert_perms)?;
101        }
102
103        log::info!(
104            "Certificate saved for {} at {} and {}",
105            self.domain,
106            self.cert_path.display(),
107            self.key_path.display()
108        );
109
110        Ok(())
111    }
112
113    pub fn days_until_expiry(&self) -> i64 {
114        let now = Utc::now();
115        (self.expires_at - now).num_days()
116    }
117
118    pub fn needs_renewal(&self, days_before_expiry: i64) -> bool {
119        self.auto_renew && self.days_until_expiry() <= days_before_expiry
120    }
121
122    pub fn is_expired(&self) -> bool {
123        Utc::now() > self.expires_at
124    }
125
126    pub fn covers_domain(&self, domain: &str) -> bool {
127        self.domain == domain || self.san_domains.contains(&domain.to_string())
128    }
129
130    fn parse_certificate(cert_data: &[u8]) -> Result<CertificateInfo, Box<dyn std::error::Error>> {
131        let mut reader = BufReader::new(cert_data);
132        let certs = certs(&mut reader)
133            .collect::<Result<Vec<_>, _>>()
134            .map_err(|e| format!("Failed to parse certificate: {e}"))?;
135
136        if certs.is_empty() {
137            return Err("No certificates found in file".into());
138        }
139
140        let cert = &certs[0];
141        let (_, parsed_cert) = X509Certificate::from_der(cert.as_ref())
142            .map_err(|e| format!("Failed to parse X509 certificate: {e}"))?;
143
144        let issued_at = DateTime::from_timestamp(parsed_cert.validity().not_before.timestamp(), 0)
145            .unwrap_or_else(Utc::now);
146
147        let expires_at = DateTime::from_timestamp(parsed_cert.validity().not_after.timestamp(), 0)
148            .unwrap_or_else(|| Utc::now() + chrono::Duration::days(90));
149
150        let issuer = parsed_cert
151            .issuer()
152            .iter_common_name()
153            .next()
154            .and_then(|cn| cn.as_str().ok())
155            .unwrap_or("Unknown")
156            .to_string();
157
158        // SAN parsing would require more complex ASN.1 parsing
159        // For now, we rely on subject CN which covers most use cases
160        // Multi-domain certificates would need proper x509-parser SAN extension parsing
161        let san_domains = Vec::new();
162        log::debug!(
163            "Using subject CN for certificate validation; SAN extension parsing not implemented"
164        );
165
166        Ok(CertificateInfo {
167            issued_at,
168            expires_at,
169            issuer,
170            san_domains,
171        })
172    }
173
174    pub async fn validate_certificate_files(&self) -> Result<bool, Box<dyn std::error::Error>> {
175        // Check if files exist
176        if !self.cert_path.exists() || !self.key_path.exists() {
177            return Ok(false);
178        }
179
180        // Try to load certificate
181        let cert_data = fs::read(&self.cert_path).await?;
182        let mut cert_reader = BufReader::new(cert_data.as_slice());
183        let certs_result = certs(&mut cert_reader).collect::<Result<Vec<_>, _>>();
184        if certs_result.is_err() {
185            return Ok(false);
186        }
187
188        // Try to load private key
189        let key_data = fs::read(&self.key_path).await?;
190        let mut key_reader = BufReader::new(key_data.as_slice());
191        let key_result = private_key(&mut key_reader);
192        if key_result.is_err() {
193            return Ok(false);
194        }
195
196        Ok(true)
197    }
198
199    pub async fn get_rustls_config(
200        &self,
201    ) -> Result<rustls::ServerConfig, Box<dyn std::error::Error>> {
202        // Load certificate chain
203        let cert_data = fs::read(&self.cert_path).await?;
204        let mut cert_reader = BufReader::new(cert_data.as_slice());
205        let cert_chain = certs(&mut cert_reader)
206            .collect::<Result<Vec<_>, _>>()
207            .map_err(|e| format!("Failed to load certificate: {e}"))?;
208
209        // Load private key
210        let key_data = fs::read(&self.key_path).await?;
211        let mut key_reader = BufReader::new(key_data.as_slice());
212        let private_key = private_key(&mut key_reader)
213            .map_err(|e| format!("Failed to load private key: {e}"))?
214            .ok_or("No private key found")?;
215
216        // Create rustls config
217        let config = rustls::ServerConfig::builder()
218            .with_no_client_auth()
219            .with_single_cert(cert_chain, private_key)
220            .map_err(|e| format!("Invalid certificate/key: {e}"))?;
221
222        Ok(config)
223    }
224}
225
226impl CertificateStore {
227    pub fn new(storage_path: PathBuf) -> Self {
228        Self {
229            certificates: Vec::new(),
230            storage_path,
231        }
232    }
233
234    pub async fn load(&mut self) -> Result<(), Box<dyn std::error::Error>> {
235        if !self.storage_path.exists() {
236            log::info!("Certificate store file not found, starting with empty store");
237            return Ok(());
238        }
239
240        let data = fs::read_to_string(&self.storage_path).await?;
241        let store: CertificateStore = toml::from_str(&data)?;
242        self.certificates = store.certificates;
243
244        log::info!("Loaded {} certificates from store", self.certificates.len());
245        Ok(())
246    }
247
248    pub async fn save(&self) -> Result<(), Box<dyn std::error::Error>> {
249        let data = toml::to_string_pretty(self)?;
250        if let Some(parent) = self.storage_path.parent() {
251            fs::create_dir_all(parent).await?;
252        }
253        fs::write(&self.storage_path, data).await?;
254        log::info!(
255            "Saved certificate store with {} certificates",
256            self.certificates.len()
257        );
258        Ok(())
259    }
260
261    pub fn add_certificate(&mut self, certificate: Certificate) {
262        // Remove any existing certificate for the same domain
263        self.certificates
264            .retain(|cert| cert.domain != certificate.domain);
265        self.certificates.push(certificate);
266    }
267
268    pub fn get_certificate(&self, domain: &str) -> Option<&Certificate> {
269        self.certificates
270            .iter()
271            .find(|cert| cert.covers_domain(domain))
272    }
273
274    pub fn has_certificate(&self, domain: &str) -> bool {
275        self.certificates
276            .iter()
277            .any(|cert| cert.covers_domain(domain))
278    }
279
280    pub fn get_certificates_needing_renewal(&self, days_before_expiry: i64) -> Vec<&Certificate> {
281        self.certificates
282            .iter()
283            .filter(|cert| cert.needs_renewal(days_before_expiry))
284            .collect()
285    }
286
287    pub fn get_expired_certificates(&self) -> Vec<&Certificate> {
288        self.certificates
289            .iter()
290            .filter(|cert| cert.is_expired())
291            .collect()
292    }
293
294    pub fn update_renewal_check(&mut self, domain: &str) {
295        if let Some(cert) = self.certificates.iter_mut().find(|c| c.domain == domain) {
296            cert.last_renewal_check = Some(Utc::now());
297        }
298    }
299
300    pub fn remove_certificate(&mut self, domain: &str) -> bool {
301        let original_len = self.certificates.len();
302        self.certificates.retain(|cert| cert.domain != domain);
303        self.certificates.len() != original_len
304    }
305
306    pub fn list_certificates(&self) -> &[Certificate] {
307        &self.certificates
308    }
309
310    /// Get all domains managed by this certificate store
311    pub fn get_all_domains(&self) -> Vec<String> {
312        self.certificates
313            .iter()
314            .map(|cert| cert.domain.clone())
315            .collect()
316    }
317
318    /// Get certificate expiry date for a domain
319    pub fn get_certificate_expiry(
320        &self,
321        domain: &str,
322    ) -> Result<Option<chrono::DateTime<chrono::Utc>>, Box<dyn std::error::Error + Send + Sync>>
323    {
324        if let Some(cert) = self.certificates.iter().find(|c| c.domain == domain) {
325            Ok(Some(cert.expires_at))
326        } else {
327            Ok(None)
328        }
329    }
330}
331
/// Metadata parsed from the first certificate in a PEM blob, used to populate a `Certificate`.
#[derive(Debug)]
struct CertificateInfo {
    /// Start of the validity period (notBefore).
    issued_at: DateTime<Utc>,
    /// End of the validity period (notAfter).
    expires_at: DateTime<Utc>,
    /// Issuer common name, or "Unknown" if the issuer has none.
    issuer: String,
    /// Extra names copied into `Certificate::san_domains`.
    san_domains: Vec<String>,
}
339
// Helper functions for certificate file paths and directory setup
/// Returns the location of the PEM certificate file for `domain`, i.e. `<cert_dir>/<domain>.crt`.
pub fn get_certificate_path(domain: &str, cert_dir: &str) -> PathBuf {
    let mut path = PathBuf::from(cert_dir);
    path.push(format!("{domain}.crt"));
    path
}
344
/// Returns the location of the PEM private-key file for `domain`, i.e. `<cert_dir>/<domain>.key`.
pub fn get_key_path(domain: &str, cert_dir: &str) -> PathBuf {
    let file_name = format!("{domain}.key");
    let mut path = PathBuf::from(cert_dir);
    path.push(file_name);
    path
}
348
349pub async fn ensure_certificate_directory(
350    cert_dir: &str,
351) -> Result<(), Box<dyn std::error::Error>> {
352    let path = PathBuf::from(cert_dir);
353    if !path.exists() {
354        fs::create_dir_all(&path).await?;
355        log::info!("Created certificate directory: {}", path.display());
356    }
357    Ok(())
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds an in-memory certificate for `domain` (plus SAN `www.<domain>`) that expires
    /// `days_left` days from now.
    fn sample_certificate(domain: &str, days_left: i64, auto_renew: bool) -> Certificate {
        let now = Utc::now();
        Certificate {
            domain: domain.to_string(),
            cert_path: PathBuf::from("test.crt"),
            key_path: PathBuf::from("test.key"),
            issued_at: now - chrono::Duration::days(60),
            expires_at: now + chrono::Duration::days(days_left),
            issuer: "Test CA".to_string(),
            san_domains: vec![format!("www.{domain}")],
            auto_renew,
            last_renewal_check: None,
        }
    }

    #[tokio::test]
    async fn test_certificate_store() {
        let temp_dir = tempdir().unwrap();
        let store_path = temp_dir.path().join("certificates.toml");

        let store = CertificateStore::new(store_path.clone());

        // Saving an empty store creates the file.
        store.save().await.unwrap();
        assert!(store_path.exists());

        // Loading it back yields no certificates.
        let mut store2 = CertificateStore::new(store_path);
        store2.load().await.unwrap();
        assert_eq!(store2.certificates.len(), 0);
    }

    #[test]
    fn test_certificate_expiry() {
        let cert = sample_certificate("example.com", 30, true);

        // Time passes between construction and the check, so allow 29 or 30 days.
        let days_until_expiry = cert.days_until_expiry();
        assert!((29..=30).contains(&days_until_expiry));
        assert!(cert.needs_renewal(45)); // threshold 45 >= ~30 days left: renew
        assert!(!cert.needs_renewal(25)); // threshold 25 < ~30 days left: not yet
        assert!(!cert.is_expired());
        assert!(cert.covers_domain("example.com"));
        assert!(cert.covers_domain("www.example.com"));
        assert!(!cert.covers_domain("other.com"));
    }

    #[test]
    fn test_expired_certificate_without_auto_renew() {
        let cert = sample_certificate("example.com", -1, false);
        assert!(cert.is_expired());
        assert!(cert.days_until_expiry() < 0);
        // Renewal is never requested when auto_renew is disabled, even after expiry.
        assert!(!cert.needs_renewal(45));
    }

    #[test]
    fn test_store_add_replace_remove() {
        let mut store = CertificateStore::new(PathBuf::from("unused.toml"));
        store.add_certificate(sample_certificate("example.com", 30, true));
        // Same primary domain: replaces the earlier entry instead of duplicating it.
        store.add_certificate(sample_certificate("example.com", 60, true));
        assert_eq!(store.list_certificates().len(), 1);
        assert!(store.get_certificate("example.com").unwrap().days_until_expiry() >= 59);

        assert!(store.has_certificate("www.example.com"));
        assert_eq!(store.get_all_domains(), vec!["example.com".to_string()]);

        assert!(store.remove_certificate("example.com"));
        assert!(!store.remove_certificate("example.com"));
        assert!(!store.has_certificate("example.com"));
    }

    #[test]
    fn test_certificate_paths() {
        assert_eq!(
            get_certificate_path("example.com", "certs"),
            PathBuf::from("certs").join("example.com.crt")
        );
        assert_eq!(
            get_key_path("example.com", "certs"),
            PathBuf::from("certs").join("example.com.key")
        );
    }
}