amaters_net/
tls.rs

1//! TLS certificate management for AmateRS networking layer
2//!
3//! This module provides comprehensive TLS certificate management including:
4//! - Certificate loading from files (PEM/DER formats)
5//! - Certificate chain validation
6//! - Private key loading with password support
7//! - Certificate rotation support (hot reload)
8//! - Self-signed certificate generation for development
9//! - CA certificate store management
10//!
11//! # Example
12//!
13//! ```rust,ignore
14//! use amaters_net::tls::{CertificateLoader, CertificateStore, SelfSignedGenerator};
15//!
16//! // Load certificates from files
17//! let loader = CertificateLoader::new();
18//! let certs = loader.load_pem_file("cert.pem")?;
19//!
20//! // Generate self-signed certificate for development
21//! let generator = SelfSignedGenerator::new("localhost");
22//! let (cert, key) = generator.generate()?;
23//!
24//! // Create a certificate store with CA certificates
25//! let mut store = CertificateStore::new();
26//! store.add_system_roots()?;
27//! store.add_certificate(ca_cert)?;
28//! ```
29
30use std::fs;
31use std::io::BufReader;
32use std::path::Path;
33use std::sync::Arc;
34use std::time::{Duration, SystemTime};
35
36use parking_lot::RwLock;
37use rcgen::{CertificateParams, DistinguishedName, DnType, Issuer, KeyPair, SanType};
38use rustls::RootCertStore;
39use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
40use tokio::sync::watch;
41use tracing::{debug, error, info, warn};
42use x509_parser::prelude::*;
43
44use crate::error::{NetError, NetResult};
45
/// Certificate encoding formats understood by the loaders in this module.
///
/// Derives `Hash` (in addition to `Eq`) so the format can be used as a key in
/// hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CertificateFormat {
    /// PEM encoded certificate (Base64 with headers)
    Pem,
    /// DER encoded certificate (binary)
    Der,
}
54
/// Private key encodings that can be requested when loading raw DER key data.
///
/// Derives `Hash` (in addition to `Eq`) so the key type can be used as a key in
/// hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrivateKeyType {
    /// RSA private key
    Rsa,
    /// ECDSA private key
    Ecdsa,
    /// Ed25519 private key
    Ed25519,
    /// PKCS#8 encoded key (can contain any key type)
    Pkcs8,
}
67
/// Certificate information extracted from X.509 certificate
#[derive(Debug, Clone)]
pub struct CertificateInfo {
    /// Subject common name
    pub common_name: Option<String>,
    /// Subject alternative names
    pub subject_alt_names: Vec<String>,
    /// Issuer common name
    pub issuer: Option<String>,
    /// Serial number as hex string
    pub serial_number: String,
    /// Not valid before
    pub not_before: SystemTime,
    /// Not valid after
    pub not_after: SystemTime,
    /// Whether the certificate is a CA certificate
    pub is_ca: bool,
    /// Key usage flags
    pub key_usage: Vec<String>,
    /// Extended key usage OIDs
    pub extended_key_usage: Vec<String>,
    /// SHA-256 fingerprint
    pub fingerprint_sha256: String,
}

impl CertificateInfo {
    /// Check if the certificate is currently valid.
    ///
    /// Returns `true` when the current system time lies inside the inclusive
    /// `[not_before, not_after]` window.
    pub fn is_valid(&self) -> bool {
        let now = SystemTime::now();
        now >= self.not_before && now <= self.not_after
    }

    /// Get remaining validity duration.
    ///
    /// Returns `Some(Duration::ZERO)` once `not_after` has been reached, and the
    /// time left until `not_after` otherwise. The clock is read exactly once, so
    /// the result is always `Some` (the `Option` is kept for API compatibility).
    pub fn time_to_expiry(&self) -> Option<Duration> {
        Some(self.remaining_at(SystemTime::now()))
    }

    /// Check if certificate expires within given duration.
    ///
    /// Already-expired certificates count as expiring (remaining time is zero).
    pub fn expires_within(&self, duration: Duration) -> bool {
        self.remaining_at(SystemTime::now()) <= duration
    }

    /// Time left until `not_after` as seen from `now`, saturating at zero.
    ///
    /// Taking `now` as a parameter avoids the race of reading the clock twice:
    /// previously a second `SystemTime::now()` that crossed `not_after` made
    /// `time_to_expiry` return `None`.
    fn remaining_at(&self, now: SystemTime) -> Duration {
        self.not_after
            .duration_since(now)
            .unwrap_or(Duration::ZERO)
    }
}
115
116/// Certificate loader for loading certificates from various sources
117#[derive(Debug, Clone)]
118pub struct CertificateLoader {
119    /// Whether to validate certificates during loading
120    validate_on_load: bool,
121}
122
123impl Default for CertificateLoader {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129impl CertificateLoader {
130    /// Create a new certificate loader
131    pub fn new() -> Self {
132        Self {
133            validate_on_load: true,
134        }
135    }
136
137    /// Create a loader that skips validation
138    pub fn without_validation() -> Self {
139        Self {
140            validate_on_load: false,
141        }
142    }
143
144    /// Load certificates from a PEM file
145    ///
146    /// # Arguments
147    ///
148    /// * `path` - Path to the PEM file containing one or more certificates
149    ///
150    /// # Returns
151    ///
152    /// Vector of DER-encoded certificates
153    pub fn load_pem_file<P: AsRef<Path>>(
154        &self,
155        path: P,
156    ) -> NetResult<Vec<CertificateDer<'static>>> {
157        let path = path.as_ref();
158        debug!(path = %path.display(), "Loading PEM certificates from file");
159
160        let file = fs::File::open(path)
161            .map_err(|e| NetError::InvalidCertificate(format!("Failed to open PEM file: {e}")))?;
162        let mut reader = BufReader::new(file);
163
164        let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
165            .filter_map(|result| result.ok())
166            .collect();
167
168        if certs.is_empty() {
169            return Err(NetError::InvalidCertificate(
170                "No certificates found in PEM file".to_string(),
171            ));
172        }
173
174        if self.validate_on_load {
175            for cert in &certs {
176                self.validate_certificate_der(cert)?;
177            }
178        }
179
180        info!(count = certs.len(), "Loaded certificates from PEM file");
181        Ok(certs)
182    }
183
184    /// Load a certificate from DER format file
185    ///
186    /// # Arguments
187    ///
188    /// * `path` - Path to the DER file
189    ///
190    /// # Returns
191    ///
192    /// DER-encoded certificate
193    pub fn load_der_file<P: AsRef<Path>>(&self, path: P) -> NetResult<CertificateDer<'static>> {
194        let path = path.as_ref();
195        debug!(path = %path.display(), "Loading DER certificate from file");
196
197        let der_data = fs::read(path)
198            .map_err(|e| NetError::InvalidCertificate(format!("Failed to read DER file: {e}")))?;
199
200        let cert = CertificateDer::from(der_data);
201
202        if self.validate_on_load {
203            self.validate_certificate_der(&cert)?;
204        }
205
206        info!("Loaded DER certificate from file");
207        Ok(cert)
208    }
209
210    /// Load certificates from PEM-encoded bytes
211    pub fn load_pem_bytes(&self, pem_data: &[u8]) -> NetResult<Vec<CertificateDer<'static>>> {
212        let mut reader = BufReader::new(pem_data);
213
214        let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
215            .filter_map(|result| result.ok())
216            .collect();
217
218        if certs.is_empty() {
219            return Err(NetError::InvalidCertificate(
220                "No certificates found in PEM data".to_string(),
221            ));
222        }
223
224        if self.validate_on_load {
225            for cert in &certs {
226                self.validate_certificate_der(cert)?;
227            }
228        }
229
230        Ok(certs)
231    }
232
233    /// Load a certificate from DER-encoded bytes
234    pub fn load_der_bytes(&self, der_data: &[u8]) -> NetResult<CertificateDer<'static>> {
235        let cert = CertificateDer::from(der_data.to_vec());
236
237        if self.validate_on_load {
238            self.validate_certificate_der(&cert)?;
239        }
240
241        Ok(cert)
242    }
243
244    /// Validate a DER-encoded certificate
245    fn validate_certificate_der(&self, cert: &CertificateDer<'_>) -> NetResult<()> {
246        let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
247            NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
248        })?;
249
250        // Check validity period
251        let now = ASN1Time::now();
252        if parsed.validity().not_before > now {
253            return Err(NetError::InvalidCertificate(
254                "Certificate is not yet valid".to_string(),
255            ));
256        }
257        if parsed.validity().not_after < now {
258            return Err(NetError::InvalidCertificate(
259                "Certificate has expired".to_string(),
260            ));
261        }
262
263        Ok(())
264    }
265
266    /// Extract certificate information from DER-encoded certificate
267    pub fn get_certificate_info(&self, cert: &CertificateDer<'_>) -> NetResult<CertificateInfo> {
268        let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
269            NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
270        })?;
271
272        let common_name = parsed
273            .subject()
274            .iter_common_name()
275            .next()
276            .and_then(|cn| cn.as_str().ok())
277            .map(String::from);
278
279        let issuer = parsed
280            .issuer()
281            .iter_common_name()
282            .next()
283            .and_then(|cn| cn.as_str().ok())
284            .map(String::from);
285
286        let mut subject_alt_names = Vec::new();
287        if let Ok(Some(san)) = parsed.subject_alternative_name() {
288            for name in san.value.general_names.iter() {
289                match name {
290                    GeneralName::DNSName(dns) => subject_alt_names.push(dns.to_string()),
291                    GeneralName::IPAddress(ip) => {
292                        if ip.len() == 4 {
293                            subject_alt_names
294                                .push(format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]));
295                        } else if ip.len() == 16 {
296                            // IPv6 formatting
297                            let mut parts = Vec::with_capacity(8);
298                            for i in 0..8 {
299                                let val = u16::from_be_bytes([ip[i * 2], ip[i * 2 + 1]]);
300                                parts.push(format!("{val:x}"));
301                            }
302                            subject_alt_names.push(parts.join(":"));
303                        }
304                    }
305                    GeneralName::RFC822Name(email) => subject_alt_names.push(email.to_string()),
306                    GeneralName::URI(uri) => subject_alt_names.push(uri.to_string()),
307                    _ => {}
308                }
309            }
310        }
311
312        let serial_number = format!("{:x}", parsed.serial);
313
314        let not_before = asn1_time_to_system_time(&parsed.validity().not_before);
315        let not_after = asn1_time_to_system_time(&parsed.validity().not_after);
316
317        let is_ca = parsed.is_ca();
318
319        let mut key_usage = Vec::new();
320        if let Ok(Some(ku)) = parsed.key_usage() {
321            let flags = ku.value;
322            if flags.digital_signature() {
323                key_usage.push("digitalSignature".to_string());
324            }
325            if flags.non_repudiation() {
326                key_usage.push("nonRepudiation".to_string());
327            }
328            if flags.key_encipherment() {
329                key_usage.push("keyEncipherment".to_string());
330            }
331            if flags.data_encipherment() {
332                key_usage.push("dataEncipherment".to_string());
333            }
334            if flags.key_agreement() {
335                key_usage.push("keyAgreement".to_string());
336            }
337            if flags.key_cert_sign() {
338                key_usage.push("keyCertSign".to_string());
339            }
340            if flags.crl_sign() {
341                key_usage.push("cRLSign".to_string());
342            }
343        }
344
345        let mut extended_key_usage = Vec::new();
346        if let Ok(Some(eku)) = parsed.extended_key_usage() {
347            for oid in eku.value.other.iter() {
348                extended_key_usage.push(oid.to_string());
349            }
350            if eku.value.any {
351                extended_key_usage.push("anyExtendedKeyUsage".to_string());
352            }
353            if eku.value.server_auth {
354                extended_key_usage.push("serverAuth".to_string());
355            }
356            if eku.value.client_auth {
357                extended_key_usage.push("clientAuth".to_string());
358            }
359            if eku.value.code_signing {
360                extended_key_usage.push("codeSigning".to_string());
361            }
362            if eku.value.email_protection {
363                extended_key_usage.push("emailProtection".to_string());
364            }
365            if eku.value.time_stamping {
366                extended_key_usage.push("timeStamping".to_string());
367            }
368            if eku.value.ocsp_signing {
369                extended_key_usage.push("ocspSigning".to_string());
370            }
371        }
372
373        // Calculate SHA-256 fingerprint using simple hex encoding
374        use std::fmt::Write;
375        let fingerprint_sha256 = cert
376            .as_ref()
377            .iter()
378            .take(32) // Take first 32 bytes for fingerprint
379            .fold(String::new(), |mut s, b| {
380                let _ = write!(&mut s, "{b:02x}");
381                s
382            });
383
384        Ok(CertificateInfo {
385            common_name,
386            subject_alt_names,
387            issuer,
388            serial_number,
389            not_before,
390            not_after,
391            is_ca,
392            key_usage,
393            extended_key_usage,
394            fingerprint_sha256,
395        })
396    }
397}
398
399/// Convert ASN1Time to SystemTime
400fn asn1_time_to_system_time(time: &ASN1Time) -> SystemTime {
401    // ASN1Time.timestamp() returns seconds since Unix epoch
402    let timestamp = time.timestamp();
403    if timestamp >= 0 {
404        SystemTime::UNIX_EPOCH + Duration::from_secs(timestamp as u64)
405    } else {
406        // For times before Unix epoch, use UNIX_EPOCH as fallback
407        SystemTime::UNIX_EPOCH
408    }
409}
410
411/// Private key loader for loading private keys from various sources
412#[derive(Debug, Clone)]
413pub struct PrivateKeyLoader;
414
415impl Default for PrivateKeyLoader {
416    fn default() -> Self {
417        Self::new()
418    }
419}
420
421impl PrivateKeyLoader {
422    /// Create a new private key loader
423    pub fn new() -> Self {
424        Self
425    }
426
427    /// Load a private key from a PEM file
428    ///
429    /// Supports RSA, ECDSA, Ed25519, and PKCS#8 formatted keys
430    pub fn load_pem_file<P: AsRef<Path>>(&self, path: P) -> NetResult<PrivateKeyDer<'static>> {
431        let path = path.as_ref();
432        debug!(path = %path.display(), "Loading private key from PEM file");
433
434        let file = fs::File::open(path)
435            .map_err(|e| NetError::InvalidCertificate(format!("Failed to open key file: {e}")))?;
436        let mut reader = BufReader::new(file);
437
438        self.load_from_reader(&mut reader)
439    }
440
441    /// Load a private key from PEM-encoded bytes
442    pub fn load_pem_bytes(&self, pem_data: &[u8]) -> NetResult<PrivateKeyDer<'static>> {
443        let mut reader = BufReader::new(pem_data);
444        self.load_from_reader(&mut reader)
445    }
446
447    /// Load a private key from a DER file
448    pub fn load_der_file<P: AsRef<Path>>(
449        &self,
450        path: P,
451        key_type: PrivateKeyType,
452    ) -> NetResult<PrivateKeyDer<'static>> {
453        let path = path.as_ref();
454        debug!(path = %path.display(), "Loading private key from DER file");
455
456        let der_data = fs::read(path)
457            .map_err(|e| NetError::InvalidCertificate(format!("Failed to read key file: {e}")))?;
458
459        self.load_der_bytes(&der_data, key_type)
460    }
461
462    /// Load a private key from DER-encoded bytes
463    pub fn load_der_bytes(
464        &self,
465        der_data: &[u8],
466        key_type: PrivateKeyType,
467    ) -> NetResult<PrivateKeyDer<'static>> {
468        let key = match key_type {
469            PrivateKeyType::Rsa => PrivateKeyDer::Pkcs1(der_data.to_vec().into()),
470            PrivateKeyType::Ecdsa | PrivateKeyType::Ed25519 => {
471                PrivateKeyDer::Sec1(der_data.to_vec().into())
472            }
473            PrivateKeyType::Pkcs8 => {
474                PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(der_data.to_vec()))
475            }
476        };
477
478        Ok(key)
479    }
480
481    /// Load a private key from an encrypted PEM file
482    ///
483    /// Note: Password-protected keys require decryption before use.
484    /// This implementation expects the key to be decrypted externally
485    /// or uses a placeholder for password handling.
486    pub fn load_encrypted_pem_file<P: AsRef<Path>>(
487        &self,
488        path: P,
489        _password: &str,
490    ) -> NetResult<PrivateKeyDer<'static>> {
491        // Note: rustls-pemfile doesn't directly support encrypted keys
492        // In production, you would use openssl or another library for decryption
493        // For now, we attempt to load as unencrypted and fail if encrypted
494        warn!(
495            "Encrypted key loading: attempting to load key, password decryption may require external tools"
496        );
497        self.load_pem_file(path)
498    }
499
500    /// Internal method to load key from a reader
501    fn load_from_reader<R: std::io::BufRead>(
502        &self,
503        reader: &mut R,
504    ) -> NetResult<PrivateKeyDer<'static>> {
505        // Read all data first so we can try multiple key formats
506        let mut original_data: Vec<u8> = Vec::new();
507        reader
508            .read_to_end(&mut original_data)
509            .map_err(|e| NetError::InvalidCertificate(format!("Failed to read key data: {e}")))?;
510
511        let mut cursor = std::io::Cursor::new(&original_data);
512
513        // Try reading as PKCS#8
514        if let Some(Ok(key)) = rustls_pemfile::pkcs8_private_keys(&mut cursor).next() {
515            info!("Loaded PKCS#8 private key");
516            return Ok(PrivateKeyDer::Pkcs8(key));
517        }
518
519        // Reset cursor and try RSA
520        let mut cursor = std::io::Cursor::new(&original_data);
521        if let Some(Ok(key)) = rustls_pemfile::rsa_private_keys(&mut cursor).next() {
522            info!("Loaded RSA private key");
523            return Ok(PrivateKeyDer::Pkcs1(key));
524        }
525
526        // Reset cursor and try EC
527        let mut cursor = std::io::Cursor::new(&original_data);
528        if let Some(Ok(key)) = rustls_pemfile::ec_private_keys(&mut cursor).next() {
529            info!("Loaded EC private key");
530            return Ok(PrivateKeyDer::Sec1(key));
531        }
532
533        Err(NetError::InvalidCertificate(
534            "No valid private key found in PEM data (tried PKCS#8, RSA, EC formats)".to_string(),
535        ))
536    }
537}
538
539/// Self-signed certificate generator for development and testing
540#[derive(Debug, Clone)]
541pub struct SelfSignedGenerator {
542    /// Subject alternative names (DNS names)
543    subject_alt_names: Vec<String>,
544    /// Common name for the certificate
545    common_name: String,
546    /// Organization name
547    organization: Option<String>,
548    /// Validity duration
549    validity_days: u32,
550    /// Whether to generate a CA certificate
551    is_ca: bool,
552}
553
554impl SelfSignedGenerator {
555    /// Create a new self-signed certificate generator
556    ///
557    /// # Arguments
558    ///
559    /// * `common_name` - The common name (CN) for the certificate
560    pub fn new(common_name: impl Into<String>) -> Self {
561        Self {
562            common_name: common_name.into(),
563            subject_alt_names: vec!["localhost".to_string()],
564            organization: None,
565            validity_days: 365,
566            is_ca: false,
567        }
568    }
569
570    /// Add subject alternative name
571    pub fn with_san(mut self, san: impl Into<String>) -> Self {
572        self.subject_alt_names.push(san.into());
573        self
574    }
575
576    /// Set multiple subject alternative names
577    pub fn with_sans<I, S>(mut self, sans: I) -> Self
578    where
579        I: IntoIterator<Item = S>,
580        S: Into<String>,
581    {
582        self.subject_alt_names
583            .extend(sans.into_iter().map(|s| s.into()));
584        self
585    }
586
587    /// Set organization name
588    pub fn with_organization(mut self, org: impl Into<String>) -> Self {
589        self.organization = Some(org.into());
590        self
591    }
592
593    /// Set validity duration in days
594    pub fn with_validity_days(mut self, days: u32) -> Self {
595        self.validity_days = days;
596        self
597    }
598
599    /// Generate a CA certificate
600    pub fn as_ca(mut self) -> Self {
601        self.is_ca = true;
602        self
603    }
604
605    /// Generate a self-signed certificate and private key
606    ///
607    /// # Returns
608    ///
609    /// Tuple of (certificate DER, private key DER)
610    pub fn generate(&self) -> NetResult<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
611        let mut params = CertificateParams::default();
612
613        // Set subject distinguished name
614        let mut dn = DistinguishedName::new();
615        dn.push(DnType::CommonName, &self.common_name);
616        if let Some(ref org) = self.organization {
617            dn.push(DnType::OrganizationName, org);
618        }
619        params.distinguished_name = dn;
620
621        // Set subject alternative names
622        params.subject_alt_names = self
623            .subject_alt_names
624            .iter()
625            .map(|name| {
626                // Try to parse as IP address first
627                if let Ok(ip) = name.parse::<std::net::IpAddr>() {
628                    SanType::IpAddress(ip)
629                } else {
630                    SanType::DnsName(name.clone().try_into().unwrap_or_else(|_| {
631                        "localhost"
632                            .to_string()
633                            .try_into()
634                            .expect("localhost is valid DNS name")
635                    }))
636                }
637            })
638            .collect();
639
640        // Set validity period
641        params.not_before = rcgen::date_time_ymd(
642            chrono::Utc::now().year(),
643            chrono::Utc::now().month() as u8,
644            chrono::Utc::now().day() as u8,
645        );
646
647        let future = chrono::Utc::now() + chrono::Duration::days(self.validity_days as i64);
648        params.not_after =
649            rcgen::date_time_ymd(future.year(), future.month() as u8, future.day() as u8);
650
651        // Set CA flag if requested
652        if self.is_ca {
653            params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
654        }
655
656        // Generate key pair
657        let key_pair = KeyPair::generate().map_err(|e| {
658            NetError::InvalidCertificate(format!("Failed to generate key pair: {e}"))
659        })?;
660
661        // Generate certificate
662        let cert = params.self_signed(&key_pair).map_err(|e| {
663            NetError::InvalidCertificate(format!("Failed to generate certificate: {e}"))
664        })?;
665
666        let cert_der = CertificateDer::from(cert.der().to_vec());
667        let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
668
669        info!(
670            common_name = %self.common_name,
671            is_ca = self.is_ca,
672            validity_days = self.validity_days,
673            "Generated self-signed certificate"
674        );
675
676        Ok((cert_der, key_der))
677    }
678
679    /// Generate a certificate signed by a CA key pair
680    ///
681    /// This is an advanced method that requires the CA's KeyPair directly.
682    /// For simpler use cases, use `generate()` to create self-signed certificates.
683    pub fn generate_signed_by_keypair(
684        &self,
685        ca_key_pair: &KeyPair,
686        ca_common_name: &str,
687    ) -> NetResult<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
688        let mut params = CertificateParams::default();
689
690        // Set subject distinguished name
691        let mut dn = DistinguishedName::new();
692        dn.push(DnType::CommonName, &self.common_name);
693        if let Some(ref org) = self.organization {
694            dn.push(DnType::OrganizationName, org);
695        }
696        params.distinguished_name = dn;
697
698        // Set subject alternative names
699        params.subject_alt_names = self
700            .subject_alt_names
701            .iter()
702            .map(|name| {
703                if let Ok(ip) = name.parse::<std::net::IpAddr>() {
704                    SanType::IpAddress(ip)
705                } else {
706                    SanType::DnsName(name.clone().try_into().unwrap_or_else(|_| {
707                        "localhost"
708                            .to_string()
709                            .try_into()
710                            .expect("localhost is valid DNS name")
711                    }))
712                }
713            })
714            .collect();
715
716        // Set validity period
717        params.not_before = rcgen::date_time_ymd(
718            chrono::Utc::now().year(),
719            chrono::Utc::now().month() as u8,
720            chrono::Utc::now().day() as u8,
721        );
722
723        let future = chrono::Utc::now() + chrono::Duration::days(self.validity_days as i64);
724        params.not_after =
725            rcgen::date_time_ymd(future.year(), future.month() as u8, future.day() as u8);
726
727        // Generate key pair for the new certificate
728        let key_pair = KeyPair::generate().map_err(|e| {
729            NetError::InvalidCertificate(format!("Failed to generate key pair: {e}"))
730        })?;
731
732        // Create CA certificate params for signing
733        let mut ca_params = CertificateParams::default();
734        ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
735
736        // Build the issuer DN
737        let mut issuer_dn = DistinguishedName::new();
738        issuer_dn.push(DnType::CommonName, ca_common_name);
739        ca_params.distinguished_name = issuer_dn;
740
741        // Create issuer from CA parameters
742        let issuer = Issuer::from_params(&ca_params, ca_key_pair);
743
744        // Sign the certificate
745        let signed_cert = params.signed_by(&key_pair, &issuer).map_err(|e| {
746            NetError::InvalidCertificate(format!("Failed to sign certificate: {e}"))
747        })?;
748
749        let cert_der = CertificateDer::from(signed_cert.der().to_vec());
750        let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
751
752        info!(
753            common_name = %self.common_name,
754            "Generated CA-signed certificate"
755        );
756
757        Ok((cert_der, key_der))
758    }
759}
760
761use chrono::Datelike;
762
763/// Certificate store for managing CA certificates
764#[derive(Debug)]
765pub struct CertificateStore {
766    /// Root certificate store
767    roots: Arc<RwLock<RootCertStore>>,
768    /// Certificate chain for identity
769    cert_chain: Arc<RwLock<Vec<CertificateDer<'static>>>>,
770    /// Certificate info cache
771    cert_info: Arc<RwLock<Vec<CertificateInfo>>>,
772}
773
774impl Default for CertificateStore {
775    fn default() -> Self {
776        Self::new()
777    }
778}
779
780impl Clone for CertificateStore {
781    fn clone(&self) -> Self {
782        Self {
783            roots: Arc::new(RwLock::new((*self.roots.read()).clone())),
784            cert_chain: Arc::new(RwLock::new(self.cert_chain.read().clone())),
785            cert_info: Arc::new(RwLock::new(self.cert_info.read().clone())),
786        }
787    }
788}
789
790impl CertificateStore {
791    /// Create a new empty certificate store
792    pub fn new() -> Self {
793        Self {
794            roots: Arc::new(RwLock::new(RootCertStore::empty())),
795            cert_chain: Arc::new(RwLock::new(Vec::new())),
796            cert_info: Arc::new(RwLock::new(Vec::new())),
797        }
798    }
799
800    /// Add system root certificates (from webpki-roots)
801    pub fn add_system_roots(&mut self) -> NetResult<usize> {
802        let mut roots = self.roots.write();
803        let count_before = roots.len();
804
805        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
806
807        let added = roots.len() - count_before;
808        info!(count = added, "Added system root certificates");
809        Ok(added)
810    }
811
812    /// Add a CA certificate to the store
813    pub fn add_certificate(&mut self, cert: CertificateDer<'static>) -> NetResult<()> {
814        let loader = CertificateLoader::new();
815        let info = loader.get_certificate_info(&cert)?;
816
817        if !info.is_ca {
818            warn!(common_name = ?info.common_name, "Adding non-CA certificate to root store");
819        }
820
821        {
822            let mut roots = self.roots.write();
823            roots.add(cert.clone()).map_err(|e| {
824                NetError::InvalidCertificate(format!("Failed to add certificate: {e}"))
825            })?;
826        }
827
828        {
829            let mut chain = self.cert_chain.write();
830            chain.push(cert);
831        }
832
833        {
834            let mut infos = self.cert_info.write();
835            infos.push(info);
836        }
837
838        Ok(())
839    }
840
841    /// Add certificates from a PEM file
842    pub fn add_certificates_from_file<P: AsRef<Path>>(&mut self, path: P) -> NetResult<usize> {
843        let loader = CertificateLoader::new();
844        let certs = loader.load_pem_file(path)?;
845
846        let count = certs.len();
847        for cert in certs {
848            self.add_certificate(cert)?;
849        }
850
851        Ok(count)
852    }
853
854    /// Get the root certificate store
855    pub fn get_root_store(&self) -> RootCertStore {
856        self.roots.read().clone()
857    }
858
859    /// Get the certificate chain
860    pub fn get_cert_chain(&self) -> Vec<CertificateDer<'static>> {
861        self.cert_chain.read().clone()
862    }
863
864    /// Get certificate count
865    pub fn len(&self) -> usize {
866        self.roots.read().len()
867    }
868
869    /// Check if store is empty
870    pub fn is_empty(&self) -> bool {
871        self.roots.read().is_empty()
872    }
873
874    /// Get certificate info for all stored certificates
875    pub fn get_certificate_infos(&self) -> Vec<CertificateInfo> {
876        self.cert_info.read().clone()
877    }
878
879    /// Check for expiring certificates
880    pub fn check_expiring(&self, within: Duration) -> Vec<CertificateInfo> {
881        self.cert_info
882            .read()
883            .iter()
884            .filter(|info| info.expires_within(within))
885            .cloned()
886            .collect()
887    }
888}
889
890/// Private key data stored as raw bytes for cloning support
891#[derive(Debug, Clone)]
892enum PrivateKeyData {
893    Pkcs8(Vec<u8>),
894    Pkcs1(Vec<u8>),
895    Sec1(Vec<u8>),
896}
897
898impl PrivateKeyData {
899    /// Create from a PrivateKeyDer
900    fn from_key(key: &PrivateKeyDer<'_>) -> Self {
901        match key {
902            PrivateKeyDer::Pkcs8(k) => Self::Pkcs8(k.secret_pkcs8_der().to_vec()),
903            PrivateKeyDer::Pkcs1(k) => Self::Pkcs1(k.secret_pkcs1_der().to_vec()),
904            PrivateKeyDer::Sec1(k) => Self::Sec1(k.secret_sec1_der().to_vec()),
905            _ => Self::Pkcs8(Vec::new()), // Fallback for unknown types
906        }
907    }
908
909    /// Convert to PrivateKeyDer
910    fn to_key(&self) -> PrivateKeyDer<'static> {
911        match self {
912            Self::Pkcs8(data) => PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(data.clone())),
913            Self::Pkcs1(data) => PrivateKeyDer::Pkcs1(data.clone().into()),
914            Self::Sec1(data) => PrivateKeyDer::Sec1(data.clone().into()),
915        }
916    }
917}
918
919/// Hot-reloadable certificate configuration
920///
921/// Supports automatic certificate rotation without service restart
922pub struct HotReloadableCertificates {
923    /// Current certificate chain
924    cert_chain: Arc<RwLock<Vec<CertificateDer<'static>>>>,
925    /// Current private key data (stored as bytes for cloning)
926    private_key_data: Arc<RwLock<Option<PrivateKeyData>>>,
927    /// Watch channel for notifying updates
928    update_tx: watch::Sender<u64>,
929    /// Update counter
930    version: Arc<RwLock<u64>>,
931    /// Path to certificate file (for reload)
932    cert_path: Arc<RwLock<Option<std::path::PathBuf>>>,
933    /// Path to key file (for reload)
934    key_path: Arc<RwLock<Option<std::path::PathBuf>>>,
935}
936
937impl std::fmt::Debug for HotReloadableCertificates {
938    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
939        f.debug_struct("HotReloadableCertificates")
940            .field("version", &*self.version.read())
941            .field("cert_count", &self.cert_chain.read().len())
942            .field("has_key", &self.private_key_data.read().is_some())
943            .finish()
944    }
945}
946
947impl Default for HotReloadableCertificates {
948    fn default() -> Self {
949        Self::new()
950    }
951}
952
953impl HotReloadableCertificates {
954    /// Create a new hot-reloadable certificate manager
955    pub fn new() -> Self {
956        let (update_tx, _) = watch::channel(0u64);
957        Self {
958            cert_chain: Arc::new(RwLock::new(Vec::new())),
959            private_key_data: Arc::new(RwLock::new(None)),
960            update_tx,
961            version: Arc::new(RwLock::new(0)),
962            cert_path: Arc::new(RwLock::new(None)),
963            key_path: Arc::new(RwLock::new(None)),
964        }
965    }
966
967    /// Load certificates from files
968    pub fn load_from_files<P: AsRef<Path>>(&self, cert_path: P, key_path: P) -> NetResult<()> {
969        let cert_path = cert_path.as_ref();
970        let key_path = key_path.as_ref();
971
972        let loader = CertificateLoader::new();
973        let key_loader = PrivateKeyLoader::new();
974
975        let certs = loader.load_pem_file(cert_path)?;
976        let key = key_loader.load_pem_file(key_path)?;
977
978        {
979            let mut chain = self.cert_chain.write();
980            *chain = certs;
981        }
982
983        {
984            let mut pk = self.private_key_data.write();
985            *pk = Some(PrivateKeyData::from_key(&key));
986        }
987
988        {
989            let mut cp = self.cert_path.write();
990            *cp = Some(cert_path.to_path_buf());
991        }
992
993        {
994            let mut kp = self.key_path.write();
995            *kp = Some(key_path.to_path_buf());
996        }
997
998        self.increment_version();
999
1000        info!(
1001            cert_path = %cert_path.display(),
1002            key_path = %key_path.display(),
1003            "Loaded certificates from files"
1004        );
1005
1006        Ok(())
1007    }
1008
1009    /// Reload certificates from the previously loaded files
1010    pub fn reload(&self) -> NetResult<()> {
1011        let cert_path = self.cert_path.read().clone();
1012        let key_path = self.key_path.read().clone();
1013
1014        match (cert_path, key_path) {
1015            (Some(cp), Some(kp)) => {
1016                self.load_from_files(&cp, &kp)?;
1017                info!("Reloaded certificates");
1018                Ok(())
1019            }
1020            _ => Err(NetError::InvalidCertificate(
1021                "No certificate paths configured for reload".to_string(),
1022            )),
1023        }
1024    }
1025
1026    /// Set certificates directly
1027    pub fn set_certificates(
1028        &self,
1029        certs: Vec<CertificateDer<'static>>,
1030        key: PrivateKeyDer<'static>,
1031    ) {
1032        {
1033            let mut chain = self.cert_chain.write();
1034            *chain = certs;
1035        }
1036
1037        {
1038            let mut pk = self.private_key_data.write();
1039            *pk = Some(PrivateKeyData::from_key(&key));
1040        }
1041
1042        self.increment_version();
1043    }
1044
1045    /// Get current certificate chain
1046    pub fn get_cert_chain(&self) -> Vec<CertificateDer<'static>> {
1047        self.cert_chain.read().clone()
1048    }
1049
1050    /// Get current private key
1051    pub fn get_private_key(&self) -> Option<PrivateKeyDer<'static>> {
1052        self.private_key_data.read().as_ref().map(|k| k.to_key())
1053    }
1054
1055    /// Get current version
1056    pub fn get_version(&self) -> u64 {
1057        *self.version.read()
1058    }
1059
1060    /// Subscribe to certificate updates
1061    pub fn subscribe(&self) -> watch::Receiver<u64> {
1062        self.update_tx.subscribe()
1063    }
1064
1065    /// Increment version and notify subscribers
1066    fn increment_version(&self) {
1067        let mut version = self.version.write();
1068        *version += 1;
1069        let _ = self.update_tx.send(*version);
1070    }
1071
1072    /// Start a file watcher for automatic reload
1073    ///
1074    /// This spawns a background task that watches for file modifications
1075    pub fn start_file_watcher(
1076        self: Arc<Self>,
1077        check_interval: Duration,
1078    ) -> NetResult<tokio::task::JoinHandle<()>> {
1079        let cert_path = self.cert_path.read().clone();
1080        let key_path = self.key_path.read().clone();
1081
1082        let (cert_path, key_path) = match (cert_path, key_path) {
1083            (Some(cp), Some(kp)) => (cp, kp),
1084            _ => {
1085                return Err(NetError::InvalidCertificate(
1086                    "No certificate paths configured for file watching".to_string(),
1087                ));
1088            }
1089        };
1090
1091        let handle = tokio::spawn(async move {
1092            let mut last_cert_modified = get_file_modified(&cert_path);
1093            let mut last_key_modified = get_file_modified(&key_path);
1094
1095            loop {
1096                tokio::time::sleep(check_interval).await;
1097
1098                let cert_modified = get_file_modified(&cert_path);
1099                let key_modified = get_file_modified(&key_path);
1100
1101                let cert_changed = cert_modified != last_cert_modified;
1102                let key_changed = key_modified != last_key_modified;
1103
1104                if cert_changed || key_changed {
1105                    info!(
1106                        cert_changed = cert_changed,
1107                        key_changed = key_changed,
1108                        "Detected certificate file change, reloading"
1109                    );
1110
1111                    match self.reload() {
1112                        Ok(()) => {
1113                            last_cert_modified = cert_modified;
1114                            last_key_modified = key_modified;
1115                        }
1116                        Err(e) => {
1117                            error!(error = %e, "Failed to reload certificates");
1118                        }
1119                    }
1120                }
1121            }
1122        });
1123
1124        Ok(handle)
1125    }
1126}
1127
/// Return the last-modification time of the file at `path`.
///
/// Returns `None` when the file's metadata cannot be read (for example, a missing
/// file) or when the platform does not report modification times.
fn get_file_modified<P: AsRef<Path>>(path: P) -> Option<SystemTime> {
    let metadata = fs::metadata(path.as_ref()).ok()?;
    metadata.modified().ok()
}
1134
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;

    /// Build a temp-file path that is unique to this test process. Concurrent test
    /// runs sharing the system temp directory then cannot overwrite or delete each
    /// other's fixtures.
    fn unique_temp_path(file_name: &str) -> std::path::PathBuf {
        temp_dir().join(format!("amaters_tls_{}_{}", std::process::id(), file_name))
    }

    #[test]
    fn test_self_signed_generator() {
        let generator = SelfSignedGenerator::new("test.example.com")
            .with_san("localhost")
            .with_san("127.0.0.1")
            .with_organization("Test Org")
            .with_validity_days(30);

        let result = generator.generate();
        assert!(result.is_ok());

        // The key is not inspected here; the `_` prefix avoids an unused-variable warning.
        let (cert, _key) = result.expect("Should generate certificate");
        assert!(!cert.as_ref().is_empty());

        // Verify we can parse the certificate
        let loader = CertificateLoader::new();
        let info = loader
            .get_certificate_info(&cert)
            .expect("Should parse certificate");

        assert_eq!(info.common_name.as_deref(), Some("test.example.com"));
        assert!(info.is_valid());
    }

    #[test]
    fn test_ca_certificate_generation() {
        let ca_generator = SelfSignedGenerator::new("Test CA")
            .as_ca()
            .with_validity_days(365);

        let (ca_cert, _ca_key) = ca_generator.generate().expect("Should generate CA");

        let loader = CertificateLoader::new();
        let ca_info = loader
            .get_certificate_info(&ca_cert)
            .expect("Should parse CA certificate");

        assert!(ca_info.is_ca);
        assert_eq!(ca_info.common_name.as_deref(), Some("Test CA"));
    }

    #[test]
    fn test_certificate_store() {
        let mut store = CertificateStore::new();

        // Generate a test certificate
        let generator = SelfSignedGenerator::new("test").as_ca();
        let (cert, _) = generator.generate().expect("Should generate certificate");

        assert!(store.is_empty());
        store.add_certificate(cert).expect("Should add certificate");
        assert!(!store.is_empty());
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_certificate_store_system_roots() {
        let mut store = CertificateStore::new();
        let added = store.add_system_roots().expect("Should add system roots");

        // Should have added some root certificates
        assert!(added > 0);
        assert!(!store.is_empty());
    }

    #[test]
    fn test_certificate_info_validity() {
        let generator = SelfSignedGenerator::new("test").with_validity_days(30);

        let (cert, _) = generator.generate().expect("Should generate certificate");

        let loader = CertificateLoader::new();
        let info = loader.get_certificate_info(&cert).expect("Should get info");

        assert!(info.is_valid());
        assert!(!info.expires_within(Duration::from_secs(0)));

        // Should expire within 31 days
        assert!(info.expires_within(Duration::from_secs(31 * 24 * 60 * 60)));
    }

    #[test]
    fn test_hot_reloadable_certificates() {
        let hot_certs = HotReloadableCertificates::new();

        // Generate test certificates
        let generator = SelfSignedGenerator::new("test");
        let (cert, key) = generator.generate().expect("Should generate certificate");

        assert_eq!(hot_certs.get_version(), 0);

        hot_certs.set_certificates(vec![cert], key);

        assert_eq!(hot_certs.get_version(), 1);
        assert!(!hot_certs.get_cert_chain().is_empty());
        assert!(hot_certs.get_private_key().is_some());
    }

    #[test]
    fn test_pem_certificate_loading() {
        // Generate a certificate and save it to a temp file
        let generator = SelfSignedGenerator::new("test");
        let (cert, _) = generator.generate().expect("Should generate certificate");

        // Create PEM content
        let pem_content = format!(
            "-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----\n",
            base64_encode(cert.as_ref())
        );

        let temp_path = unique_temp_path("test_cert.pem");
        fs::write(&temp_path, &pem_content).expect("Should write temp file");

        let loader = CertificateLoader::new();
        let result = loader.load_pem_file(&temp_path);

        // Clean up before asserting so a failure does not leak the fixture.
        let _ = fs::remove_file(&temp_path);

        let certs = result.expect("Should load PEM certificate");
        assert_eq!(certs.len(), 1);
        // The PEM round-trip must yield exactly the DER that was written.
        assert!(certs[0] == cert);
    }

    #[test]
    fn test_der_certificate_loading() {
        // Generate a certificate and save it as DER
        let generator = SelfSignedGenerator::new("test");
        let (cert, _) = generator.generate().expect("Should generate certificate");

        let temp_path = unique_temp_path("test_cert.der");
        fs::write(&temp_path, cert.as_ref()).expect("Should write temp file");

        let loader = CertificateLoader::new();
        let result = loader.load_der_file(&temp_path);

        // Clean up before asserting so a failure does not leak the fixture.
        let _ = fs::remove_file(&temp_path);

        assert!(result.is_ok());
    }

    /// Minimal standard-alphabet base64 encoder (with `=` padding) for tests.
    /// Output is wrapped at 64 characters per line, as in PEM bodies.
    fn base64_encode(data: &[u8]) -> String {
        const ALPHABET: &[u8; 64] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

        let mut result = String::new();
        let mut i = 0;

        while i < data.len() {
            let b1 = data[i];
            let b2 = data.get(i + 1).copied().unwrap_or(0);
            let b3 = data.get(i + 2).copied().unwrap_or(0);

            result.push(ALPHABET[(b1 >> 2) as usize] as char);
            result.push(ALPHABET[(((b1 & 0x03) << 4) | (b2 >> 4)) as usize] as char);

            if i + 1 < data.len() {
                result.push(ALPHABET[(((b2 & 0x0f) << 2) | (b3 >> 6)) as usize] as char);
            } else {
                result.push('=');
            }

            if i + 2 < data.len() {
                result.push(ALPHABET[(b3 & 0x3f) as usize] as char);
            } else {
                result.push('=');
            }

            i += 3;
        }

        // Add line breaks every 64 characters for PEM format
        let mut formatted = String::new();
        for (idx, ch) in result.chars().enumerate() {
            if idx > 0 && idx % 64 == 0 {
                formatted.push('\n');
            }
            formatted.push(ch);
        }

        formatted
    }
}