// bws_web_server/ssl/certificate.rs

1use chrono::{DateTime, Utc};
2use rustls_pemfile::{certs, private_key};
3use serde::{Deserialize, Serialize};
4use std::io::BufReader;
5use std::path::PathBuf;
6use tokio::fs;
7use x509_parser::prelude::*;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Certificate {
11    pub domain: String,
12    pub cert_path: PathBuf,
13    pub key_path: PathBuf,
14    pub issued_at: DateTime<Utc>,
15    pub expires_at: DateTime<Utc>,
16    pub issuer: String,
17    pub san_domains: Vec<String>,
18    pub auto_renew: bool,
19    pub last_renewal_check: Option<DateTime<Utc>>,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CertificateStore {
24    pub certificates: Vec<Certificate>,
25    pub storage_path: PathBuf,
26}
27
28impl Certificate {
29    pub async fn from_files(
30        domain: String,
31        cert_path: PathBuf,
32        key_path: PathBuf,
33        auto_renew: bool,
34    ) -> Result<Self, Box<dyn std::error::Error>> {
35        // Read certificate file
36        let cert_data = fs::read(&cert_path).await?;
37        let cert_info = Self::parse_certificate(&cert_data)?;
38
39        Ok(Certificate {
40            domain,
41            cert_path,
42            key_path,
43            issued_at: cert_info.issued_at,
44            expires_at: cert_info.expires_at,
45            issuer: cert_info.issuer,
46            san_domains: cert_info.san_domains,
47            auto_renew,
48            last_renewal_check: None,
49        })
50    }
51
52    pub fn from_pem_data(
53        domain: String,
54        cert_path: PathBuf,
55        key_path: PathBuf,
56        cert_pem: &str,
57        auto_renew: bool,
58    ) -> Result<Self, Box<dyn std::error::Error>> {
59        // Parse certificate from PEM data
60        let cert_info = Self::parse_certificate(cert_pem.as_bytes())?;
61
62        Ok(Certificate {
63            domain,
64            cert_path,
65            key_path,
66            issued_at: cert_info.issued_at,
67            expires_at: cert_info.expires_at,
68            issuer: cert_info.issuer,
69            san_domains: cert_info.san_domains,
70            auto_renew,
71            last_renewal_check: None,
72        })
73    }
74
75    pub async fn save_certificate(
76        &self,
77        cert_pem: &str,
78        key_pem: &str,
79    ) -> Result<(), Box<dyn std::error::Error>> {
80        // Ensure parent directories exist
81        if let Some(parent) = self.cert_path.parent() {
82            fs::create_dir_all(parent).await?;
83        }
84        if let Some(parent) = self.key_path.parent() {
85            fs::create_dir_all(parent).await?;
86        }
87
88        // Write certificate and key files
89        fs::write(&self.cert_path, cert_pem).await?;
90        fs::write(&self.key_path, key_pem).await?;
91
92        // Set appropriate permissions (600 for key file)
93        #[cfg(unix)]
94        {
95            use std::os::unix::fs::PermissionsExt;
96            let key_perms = std::fs::Permissions::from_mode(0o600);
97            std::fs::set_permissions(&self.key_path, key_perms)?;
98
99            let cert_perms = std::fs::Permissions::from_mode(0o644);
100            std::fs::set_permissions(&self.cert_path, cert_perms)?;
101        }
102
103        log::info!(
104            "Certificate saved for {} at {} and {}",
105            self.domain,
106            self.cert_path.display(),
107            self.key_path.display()
108        );
109
110        Ok(())
111    }
112
113    pub fn days_until_expiry(&self) -> i64 {
114        let now = Utc::now();
115        (self.expires_at - now).num_days()
116    }
117
118    pub fn needs_renewal(&self, days_before_expiry: i64) -> bool {
119        self.auto_renew && self.days_until_expiry() <= days_before_expiry
120    }
121
122    pub fn is_expired(&self) -> bool {
123        Utc::now() > self.expires_at
124    }
125
126    pub fn covers_domain(&self, domain: &str) -> bool {
127        self.domain == domain || self.san_domains.contains(&domain.to_string())
128    }
129
130    fn parse_certificate(cert_data: &[u8]) -> Result<CertificateInfo, Box<dyn std::error::Error>> {
131        let mut reader = BufReader::new(cert_data);
132        let certs = certs(&mut reader)
133            .collect::<Result<Vec<_>, _>>()
134            .map_err(|e| format!("Failed to parse certificate: {}", e))?;
135
136        if certs.is_empty() {
137            return Err("No certificates found in file".into());
138        }
139
140        let cert = &certs[0];
141        let (_, parsed_cert) = X509Certificate::from_der(cert.as_ref())
142            .map_err(|e| format!("Failed to parse X509 certificate: {}", e))?;
143
144        let issued_at = DateTime::from_timestamp(parsed_cert.validity().not_before.timestamp(), 0)
145            .unwrap_or_else(Utc::now);
146
147        let expires_at = DateTime::from_timestamp(parsed_cert.validity().not_after.timestamp(), 0)
148            .unwrap_or_else(|| Utc::now() + chrono::Duration::days(90));
149
150        let issuer = parsed_cert
151            .issuer()
152            .iter_common_name()
153            .next()
154            .and_then(|cn| cn.as_str().ok())
155            .unwrap_or("Unknown")
156            .to_string();
157
158        // Extract SAN domains - simplified for now
159        let san_domains = Vec::new();
160        // TODO: Implement proper SAN parsing
161        log::debug!("Certificate SAN parsing not implemented - using subject CN only");
162
163        Ok(CertificateInfo {
164            issued_at,
165            expires_at,
166            issuer,
167            san_domains,
168        })
169    }
170
171    pub async fn validate_certificate_files(&self) -> Result<bool, Box<dyn std::error::Error>> {
172        // Check if files exist
173        if !self.cert_path.exists() || !self.key_path.exists() {
174            return Ok(false);
175        }
176
177        // Try to load certificate
178        let cert_data = fs::read(&self.cert_path).await?;
179        let mut cert_reader = BufReader::new(cert_data.as_slice());
180        let certs_result = certs(&mut cert_reader).collect::<Result<Vec<_>, _>>();
181        if certs_result.is_err() {
182            return Ok(false);
183        }
184
185        // Try to load private key
186        let key_data = fs::read(&self.key_path).await?;
187        let mut key_reader = BufReader::new(key_data.as_slice());
188        let key_result = private_key(&mut key_reader);
189        if key_result.is_err() {
190            return Ok(false);
191        }
192
193        Ok(true)
194    }
195
196    pub async fn get_rustls_config(
197        &self,
198    ) -> Result<rustls::ServerConfig, Box<dyn std::error::Error>> {
199        // Load certificate chain
200        let cert_data = fs::read(&self.cert_path).await?;
201        let mut cert_reader = BufReader::new(cert_data.as_slice());
202        let cert_chain = certs(&mut cert_reader)
203            .collect::<Result<Vec<_>, _>>()
204            .map_err(|e| format!("Failed to load certificate: {}", e))?;
205
206        // Load private key
207        let key_data = fs::read(&self.key_path).await?;
208        let mut key_reader = BufReader::new(key_data.as_slice());
209        let private_key = private_key(&mut key_reader)
210            .map_err(|e| format!("Failed to load private key: {}", e))?
211            .ok_or("No private key found")?;
212
213        // Create rustls config
214        let config = rustls::ServerConfig::builder()
215            .with_no_client_auth()
216            .with_single_cert(cert_chain, private_key)
217            .map_err(|e| format!("Invalid certificate/key: {}", e))?;
218
219        Ok(config)
220    }
221}
222
223impl CertificateStore {
224    pub fn new(storage_path: PathBuf) -> Self {
225        Self {
226            certificates: Vec::new(),
227            storage_path,
228        }
229    }
230
231    pub async fn load(&mut self) -> Result<(), Box<dyn std::error::Error>> {
232        if !self.storage_path.exists() {
233            log::info!("Certificate store file not found, starting with empty store");
234            return Ok(());
235        }
236
237        let data = fs::read_to_string(&self.storage_path).await?;
238        let store: CertificateStore = toml::from_str(&data)?;
239        self.certificates = store.certificates;
240
241        log::info!("Loaded {} certificates from store", self.certificates.len());
242        Ok(())
243    }
244
245    pub async fn save(&self) -> Result<(), Box<dyn std::error::Error>> {
246        let data = toml::to_string_pretty(self)?;
247        if let Some(parent) = self.storage_path.parent() {
248            fs::create_dir_all(parent).await?;
249        }
250        fs::write(&self.storage_path, data).await?;
251        log::info!(
252            "Saved certificate store with {} certificates",
253            self.certificates.len()
254        );
255        Ok(())
256    }
257
258    pub fn add_certificate(&mut self, certificate: Certificate) {
259        // Remove any existing certificate for the same domain
260        self.certificates
261            .retain(|cert| cert.domain != certificate.domain);
262        self.certificates.push(certificate);
263    }
264
265    pub fn get_certificate(&self, domain: &str) -> Option<&Certificate> {
266        self.certificates
267            .iter()
268            .find(|cert| cert.covers_domain(domain))
269    }
270
271    pub fn has_certificate(&self, domain: &str) -> bool {
272        self.certificates
273            .iter()
274            .any(|cert| cert.covers_domain(domain))
275    }
276
277    pub fn get_certificates_needing_renewal(&self, days_before_expiry: i64) -> Vec<&Certificate> {
278        self.certificates
279            .iter()
280            .filter(|cert| cert.needs_renewal(days_before_expiry))
281            .collect()
282    }
283
284    pub fn get_expired_certificates(&self) -> Vec<&Certificate> {
285        self.certificates
286            .iter()
287            .filter(|cert| cert.is_expired())
288            .collect()
289    }
290
291    pub fn update_renewal_check(&mut self, domain: &str) {
292        if let Some(cert) = self.certificates.iter_mut().find(|c| c.domain == domain) {
293            cert.last_renewal_check = Some(Utc::now());
294        }
295    }
296
297    pub fn remove_certificate(&mut self, domain: &str) -> bool {
298        let original_len = self.certificates.len();
299        self.certificates.retain(|cert| cert.domain != domain);
300        self.certificates.len() != original_len
301    }
302
303    pub fn list_certificates(&self) -> &[Certificate] {
304        &self.certificates
305    }
306
307    /// Get all domains managed by this certificate store
308    pub fn get_all_domains(&self) -> Vec<String> {
309        self.certificates
310            .iter()
311            .map(|cert| cert.domain.clone())
312            .collect()
313    }
314
315    /// Get certificate expiry date for a domain
316    pub fn get_certificate_expiry(
317        &self,
318        domain: &str,
319    ) -> Result<Option<chrono::DateTime<chrono::Utc>>, Box<dyn std::error::Error + Send + Sync>>
320    {
321        if let Some(cert) = self.certificates.iter().find(|c| c.domain == domain) {
322            Ok(Some(cert.expires_at))
323        } else {
324            Ok(None)
325        }
326    }
327}
328
329#[derive(Debug)]
330struct CertificateInfo {
331    issued_at: DateTime<Utc>,
332    expires_at: DateTime<Utc>,
333    issuer: String,
334    san_domains: Vec<String>,
335}
336
// Helper functions for certificate management

/// Returns `<cert_dir>/<domain>.crt`, the on-disk location of a domain's certificate.
///
/// NOTE(review): `domain` is used verbatim as a file name. A value containing path separators
/// or `..` would resolve outside `cert_dir`, so callers should pass validated domains.
pub fn get_certificate_path(domain: &str, cert_dir: &str) -> PathBuf {
    let mut path = PathBuf::from(cert_dir);
    path.push(format!("{}.crt", domain));
    path
}
341
/// Returns `<cert_dir>/<domain>.key`, the on-disk location of a domain's private key.
///
/// NOTE(review): `domain` is used verbatim as a file name; see `get_certificate_path`.
pub fn get_key_path(domain: &str, cert_dir: &str) -> PathBuf {
    let mut path = PathBuf::from(cert_dir);
    path.push(format!("{}.key", domain));
    path
}
345
346pub async fn ensure_certificate_directory(
347    cert_dir: &str,
348) -> Result<(), Box<dyn std::error::Error>> {
349    let path = PathBuf::from(cert_dir);
350    if !path.exists() {
351        fs::create_dir_all(&path).await?;
352        log::info!("Created certificate directory: {}", path.display());
353    }
354    Ok(())
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// An empty store survives a save/load round trip.
    #[tokio::test]
    async fn test_certificate_store() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("certificates.toml");

        // Persisting an empty store must still produce the file.
        let original = CertificateStore::new(path.clone());
        original.save().await.unwrap();
        assert!(path.exists());

        // Reading it back yields no certificates.
        let mut reloaded = CertificateStore::new(path);
        reloaded.load().await.unwrap();
        assert_eq!(reloaded.certificates.len(), 0);
    }

    /// Expiry, renewal and domain-coverage helpers on a certificate with ~30 days left.
    #[test]
    fn test_certificate_expiry() {
        let now = Utc::now();
        let cert = Certificate {
            domain: "example.com".to_string(),
            cert_path: PathBuf::from("test.crt"),
            key_path: PathBuf::from("test.key"),
            issued_at: now - chrono::Duration::days(60),
            expires_at: now + chrono::Duration::days(30),
            issuer: "Test CA".to_string(),
            san_domains: vec!["www.example.com".to_string()],
            auto_renew: true,
            last_renewal_check: None,
        };

        // Time passes between building `cert` and this call, so 29 is also acceptable.
        let remaining = cert.days_until_expiry();
        assert!((29..=30).contains(&remaining));

        // ~30 days left: inside a 45-day renewal window, outside a 25-day one.
        assert!(cert.needs_renewal(45));
        assert!(!cert.needs_renewal(25));
        assert!(!cert.is_expired());

        // Primary domain and SAN are covered; unrelated domains are not.
        assert!(cert.covers_domain("example.com"));
        assert!(cert.covers_domain("www.example.com"));
        assert!(!cert.covers_domain("other.com"));
    }
}