Skip to main content

synapse_pingora/
tls.rs

//! TLS certificate management with SNI-based certificate selection.
//!
//! This module provides secure TLS configuration loading and hot-reload
//! capabilities for multi-site certificate management.
5
6use ahash::RandomState;
7use parking_lot::RwLock;
8use std::collections::HashMap;
9use std::fs;
10use std::path::Path;
11use std::sync::Arc;
12use tracing::{debug, info, warn};
13use zeroize::Zeroize;
14
15/// Maximum certificate file size (1MB).
16const MAX_CERT_SIZE: u64 = 1024 * 1024;
17
18/// TLS certificate and key pair.
19///
20/// # Security (SEC-010)
21/// Private key is wrapped with zeroize to clear memory on drop.
22#[derive(Clone)]
23pub struct CertifiedKey {
24    /// PEM-encoded certificate chain
25    pub cert_pem: Arc<String>,
26    /// PEM-encoded private key (stored securely, zeroized on drop via Arc)
27    pub key_pem: Arc<SecureString>,
28    /// Associated domain
29    pub domain: String,
30}
31
/// Owned string holding sensitive data.
///
/// # Security (SEC-010)
/// The wrapped buffer is zeroized by this type's `Drop` impl so private key
/// material does not linger in memory after the value is released. Cloning
/// produces an independent copy with its own buffer.
#[derive(Clone)]
pub struct SecureString(String);
38
39impl SecureString {
40    pub fn new(s: String) -> Self {
41        Self(s)
42    }
43
44    pub fn len(&self) -> usize {
45        self.0.len()
46    }
47
48    pub fn is_empty(&self) -> bool {
49        self.0.is_empty()
50    }
51
52    pub fn as_str(&self) -> &str {
53        &self.0
54    }
55}
56
57impl Drop for SecureString {
58    fn drop(&mut self) {
59        self.0.zeroize();
60    }
61}
62
63impl std::fmt::Debug for SecureString {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        write!(f, "[REDACTED {} bytes]", self.0.len())
66    }
67}
68
69impl std::fmt::Debug for CertifiedKey {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        // Never log the private key
72        f.debug_struct("CertifiedKey")
73            .field("domain", &self.domain)
74            .field("cert_pem", &format!("[{} bytes]", self.cert_pem.len()))
75            .field(
76                "key_pem",
77                &format!("[REDACTED {} bytes]", self.key_pem.len()),
78            )
79            .finish()
80    }
81}
82
/// Describes where to load one TLS certificate/key pair from.
#[derive(Debug, Clone)]
pub struct TlsCertConfig {
    /// Domain name used for SNI matching (a leading `*.` is stripped for
    /// wildcard entries when the certificate is stored).
    pub domain: String,
    /// Filesystem path of the PEM certificate file.
    pub cert_path: String,
    /// Filesystem path of the PEM private key file.
    pub key_path: String,
    /// `true` when the certificate should be stored as a wildcard entry.
    pub is_wildcard: bool,
}
95
/// Outcome summary of a certificate reload.
#[derive(Debug, Clone)]
pub struct ReloadResult {
    /// How many certificates were reloaded successfully.
    pub succeeded: usize,
    /// How many certificates could not be reloaded.
    pub failed: usize,
    /// One `(domain, error message)` pair per failed reload.
    pub errors: Vec<(String, String)>,
}
106
107impl ReloadResult {
108    /// Returns true if all certificates were reloaded successfully.
109    pub fn is_success(&self) -> bool {
110        self.failed == 0
111    }
112}
113
114/// TLS manager with SNI-based certificate selection and hot reload.
115///
116/// # Performance (PERF-P2-2)
117/// Uses ahash::RandomState for 2-3x faster HashMap operations.
118pub struct TlsManager {
119    /// Exact domain -> certificate mapping (using fast ahash)
120    exact_certs: RwLock<HashMap<String, Arc<CertifiedKey>, RandomState>>,
121    /// Wildcard domain -> certificate mapping (e.g., "example.com" for *.example.com)
122    wildcard_certs: RwLock<HashMap<String, Arc<CertifiedKey>, RandomState>>,
123    /// Default certificate (if any)
124    default_cert: RwLock<Option<Arc<CertifiedKey>>>,
125    /// Minimum TLS version
126    min_version: TlsVersion,
127    /// Stored configurations for hot-reload (domain -> config)
128    cert_configs: RwLock<HashMap<String, TlsCertConfig, RandomState>>,
129    /// Default certificate config for hot-reload
130    default_cert_config: RwLock<Option<TlsCertConfig>>,
131}
132
133use std::str::FromStr;
134
/// TLS protocol versions this module knows about.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsVersion {
    /// TLS 1.2.
    Tls12,
    /// TLS 1.3.
    Tls13,
}
141
142impl FromStr for TlsVersion {
143    type Err = TlsError;
144
145    /// Parses a TLS version string.
146    fn from_str(s: &str) -> Result<Self, TlsError> {
147        match s {
148            "1.2" | "TLS1.2" | "TLSv1.2" => Ok(TlsVersion::Tls12),
149            "1.3" | "TLS1.3" | "TLSv1.3" => Ok(TlsVersion::Tls13),
150            _ => Err(TlsError::InvalidVersion {
151                version: s.to_string(),
152            }),
153        }
154    }
155}
156
157/// Errors that can occur during TLS operations.
158#[derive(Debug, thiserror::Error)]
159pub enum TlsError {
160    #[error("certificate file not found: {path}")]
161    CertNotFound { path: String },
162
163    #[error("key file not found: {path}")]
164    KeyNotFound { path: String },
165
166    #[error("certificate file too large: {path} ({size} bytes, max {max} bytes)")]
167    CertTooLarge { path: String, size: u64, max: u64 },
168
169    #[error("key file too large: {path} ({size} bytes, max {max} bytes)")]
170    KeyTooLarge { path: String, size: u64, max: u64 },
171
172    #[error("failed to read certificate: {0}")]
173    ReadError(#[from] std::io::Error),
174
175    #[error("invalid TLS version: {version} (must be 1.2 or 1.3)")]
176    InvalidVersion { version: String },
177
178    #[error("path traversal detected in: {path}")]
179    PathTraversal { path: String },
180
181    #[error("no certificate found for domain: {domain}")]
182    NoCertificate { domain: String },
183
184    #[error("invalid certificate format: {reason}")]
185    InvalidCertificate { reason: String },
186}
187
188impl TlsManager {
189    /// Creates a new TLS manager with the specified minimum version.
190    pub fn new(min_version: TlsVersion) -> Self {
191        Self {
192            exact_certs: RwLock::new(HashMap::with_hasher(RandomState::new())),
193            wildcard_certs: RwLock::new(HashMap::with_hasher(RandomState::new())),
194            default_cert: RwLock::new(None),
195            min_version,
196            cert_configs: RwLock::new(HashMap::with_hasher(RandomState::new())),
197            default_cert_config: RwLock::new(None),
198        }
199    }
200
201    /// Creates a TLS manager with TLS 1.2 minimum.
202    pub fn with_tls12_minimum() -> Self {
203        Self::new(TlsVersion::Tls12)
204    }
205
206    /// Loads a certificate from files.
207    ///
208    /// # Security
209    /// - Validates file paths for traversal attacks
210    /// - Enforces file size limits
211    /// - Never logs private key paths or contents
212    pub fn load_cert(&self, config: &TlsCertConfig) -> Result<(), TlsError> {
213        // Validate paths for traversal
214        Self::validate_path(&config.cert_path)?;
215        Self::validate_path(&config.key_path)?;
216
217        // Load certificate
218        let cert_pem = Self::read_file_secure(&config.cert_path, MAX_CERT_SIZE, "certificate")?;
219
220        // Load private key
221        let key_pem = Self::read_file_secure(&config.key_path, MAX_CERT_SIZE, "key")?;
222
223        // Create certified key (Arc for efficient sharing)
224        // SEC-010: Private key wrapped in SecureString for zeroization
225        let certified_key = Arc::new(CertifiedKey {
226            cert_pem: Arc::new(cert_pem),
227            key_pem: Arc::new(SecureString::new(key_pem)),
228            domain: config.domain.clone(),
229        });
230
231        // Store based on type
232        let storage_key = if config.is_wildcard {
233            // Store wildcard by base domain (e.g., "example.com" for *.example.com)
234            let base_domain = config.domain.trim_start_matches("*.");
235            let mut wildcards = self.wildcard_certs.write();
236            wildcards.insert(base_domain.to_lowercase(), certified_key);
237            info!("Loaded wildcard TLS certificate for *.{}", base_domain);
238            base_domain.to_lowercase()
239        } else {
240            let mut exact = self.exact_certs.write();
241            exact.insert(config.domain.to_lowercase(), certified_key);
242            debug!("Loaded TLS certificate for {}", config.domain);
243            config.domain.to_lowercase()
244        };
245
246        // Store config for hot-reload capability
247        {
248            let mut configs = self.cert_configs.write();
249            configs.insert(storage_key, config.clone());
250        }
251
252        Ok(())
253    }
254
255    /// Sets the default certificate for unmatched domains.
256    pub fn set_default_cert(&self, config: &TlsCertConfig) -> Result<(), TlsError> {
257        Self::validate_path(&config.cert_path)?;
258        Self::validate_path(&config.key_path)?;
259
260        let cert_pem = Self::read_file_secure(&config.cert_path, MAX_CERT_SIZE, "certificate")?;
261        let key_pem = Self::read_file_secure(&config.key_path, MAX_CERT_SIZE, "key")?;
262
263        // SEC-010: Private key wrapped in SecureString for zeroization
264        let certified_key = Arc::new(CertifiedKey {
265            cert_pem: Arc::new(cert_pem),
266            key_pem: Arc::new(SecureString::new(key_pem)),
267            domain: config.domain.clone(),
268        });
269
270        *self.default_cert.write() = Some(certified_key);
271
272        // Store config for hot-reload capability
273        *self.default_cert_config.write() = Some(config.clone());
274
275        info!("Set default TLS certificate for {}", config.domain);
276        Ok(())
277    }
278
279    /// Gets the certificate for a domain using SNI matching.
280    ///
281    /// # Matching Order
282    /// 1. Exact domain match
283    /// 2. Wildcard match (*.example.com matches sub.example.com)
284    /// 3. Default certificate
285    pub fn get_cert(&self, domain: &str) -> Option<Arc<CertifiedKey>> {
286        let normalized = domain.to_lowercase();
287
288        // Try exact match first
289        {
290            let exact = self.exact_certs.read();
291            if let Some(cert) = exact.get(&normalized) {
292                debug!("SNI exact match for {}", domain);
293                return Some(Arc::clone(cert));
294            }
295        }
296
297        // Try wildcard match
298        if let Some(base_domain) = Self::get_base_domain(&normalized) {
299            let wildcards = self.wildcard_certs.read();
300            if let Some(cert) = wildcards.get(base_domain) {
301                debug!("SNI wildcard match for {} -> *.{}", domain, base_domain);
302                return Some(Arc::clone(cert));
303            }
304        }
305
306        // Fall back to default
307        {
308            let default = self.default_cert.read();
309            if let Some(cert) = default.as_ref() {
310                debug!("Using default certificate for {}", domain);
311                return Some(Arc::clone(cert));
312            }
313        }
314
315        warn!("No TLS certificate found for domain: {}", domain);
316        None
317    }
318
319    /// Gets the base domain for wildcard matching.
320    /// e.g., "sub.example.com" -> "example.com"
321    fn get_base_domain(domain: &str) -> Option<&str> {
322        let parts: Vec<&str> = domain.split('.').collect();
323        if parts.len() >= 2 {
324            // Skip the first part (subdomain)
325            let base_start = domain.find('.').map(|i| i + 1)?;
326            Some(&domain[base_start..])
327        } else {
328            None
329        }
330    }
331
332    /// Validates a file path for security issues.
333    fn validate_path(path: &str) -> Result<(), TlsError> {
334        // Check for path traversal
335        if path.contains("..") {
336            return Err(TlsError::PathTraversal {
337                path: path.to_string(),
338            });
339        }
340        Ok(())
341    }
342
343    /// Reads a file with size validation.
344    fn read_file_secure(path: &str, max_size: u64, file_type: &str) -> Result<String, TlsError> {
345        let path_ref = Path::new(path);
346
347        if !path_ref.exists() {
348            return Err(if file_type == "certificate" {
349                TlsError::CertNotFound {
350                    path: path.to_string(),
351                }
352            } else {
353                TlsError::KeyNotFound {
354                    path: path.to_string(),
355                }
356            });
357        }
358
359        let metadata = fs::metadata(path)?;
360        if metadata.len() > max_size {
361            return Err(if file_type == "certificate" {
362                TlsError::CertTooLarge {
363                    path: path.to_string(),
364                    size: metadata.len(),
365                    max: max_size,
366                }
367            } else {
368                TlsError::KeyTooLarge {
369                    path: path.to_string(),
370                    size: metadata.len(),
371                    max: max_size,
372                }
373            });
374        }
375
376        fs::read_to_string(path).map_err(TlsError::from)
377    }
378
379    /// Reloads all certificates from their original paths.
380    /// This is called on SIGHUP for hot reload.
381    ///
382    /// # Hot Reload Strategy
383    /// Certificates are reloaded atomically: new certificates are loaded into
384    /// temporary maps, then swapped in all at once. If any certificate fails
385    /// to load, all successfully loaded certificates are still applied and
386    /// failures are reported.
387    ///
388    /// # Returns
389    /// `ReloadResult` containing counts of succeeded/failed reloads and error details.
390    pub fn reload_all(&self) -> ReloadResult {
391        info!("Reloading all TLS certificates...");
392
393        let mut result = ReloadResult {
394            succeeded: 0,
395            failed: 0,
396            errors: Vec::new(),
397        };
398
399        // Snapshot current configs to avoid holding lock during reload
400        let configs: Vec<(String, TlsCertConfig)> = {
401            let configs = self.cert_configs.read();
402            configs
403                .iter()
404                .map(|(k, v)| (k.clone(), v.clone()))
405                .collect()
406        };
407
408        let default_config: Option<TlsCertConfig> = { self.default_cert_config.read().clone() };
409
410        if configs.is_empty() && default_config.is_none() {
411            info!("No certificates configured for reload");
412            return result;
413        }
414
415        // Prepare new certificate maps
416        let mut new_exact: HashMap<String, Arc<CertifiedKey>, RandomState> =
417            HashMap::with_hasher(RandomState::new());
418        let mut new_wildcard: HashMap<String, Arc<CertifiedKey>, RandomState> =
419            HashMap::with_hasher(RandomState::new());
420
421        // Reload each certificate
422        for (storage_key, config) in configs {
423            match self.load_cert_internal(&config) {
424                Ok(certified_key) => {
425                    if config.is_wildcard {
426                        new_wildcard.insert(storage_key, certified_key);
427                    } else {
428                        new_exact.insert(storage_key, certified_key);
429                    }
430                    result.succeeded += 1;
431                    debug!("Reloaded certificate for {}", config.domain);
432                }
433                Err(e) => {
434                    result.failed += 1;
435                    result.errors.push((config.domain.clone(), e.to_string()));
436                    warn!("Failed to reload certificate for {}: {}", config.domain, e);
437                }
438            }
439        }
440
441        // Reload default certificate
442        let new_default = if let Some(config) = default_config {
443            match self.load_cert_internal(&config) {
444                Ok(certified_key) => {
445                    result.succeeded += 1;
446                    debug!("Reloaded default certificate for {}", config.domain);
447                    Some(certified_key)
448                }
449                Err(e) => {
450                    result.failed += 1;
451                    result
452                        .errors
453                        .push((format!("default:{}", config.domain), e.to_string()));
454                    warn!(
455                        "Failed to reload default certificate for {}: {}",
456                        config.domain, e
457                    );
458                    None
459                }
460            }
461        } else {
462            None
463        };
464
465        // Atomic swap: apply all successfully loaded certificates
466        if result.succeeded > 0 {
467            // Swap exact certs
468            if !new_exact.is_empty() {
469                let mut exact = self.exact_certs.write();
470                for (key, cert) in new_exact {
471                    exact.insert(key, cert);
472                }
473            }
474
475            // Swap wildcard certs
476            if !new_wildcard.is_empty() {
477                let mut wildcards = self.wildcard_certs.write();
478                for (key, cert) in new_wildcard {
479                    wildcards.insert(key, cert);
480                }
481            }
482
483            // Swap default cert
484            if let Some(cert) = new_default {
485                *self.default_cert.write() = Some(cert);
486            }
487        }
488
489        if result.is_success() {
490            info!("Successfully reloaded {} certificate(s)", result.succeeded);
491        } else {
492            warn!(
493                "Certificate reload completed: {} succeeded, {} failed",
494                result.succeeded, result.failed
495            );
496        }
497
498        result
499    }
500
501    /// Reloads a single certificate by domain.
502    ///
503    /// # Arguments
504    /// * `domain` - The domain to reload (case-insensitive)
505    ///
506    /// # Returns
507    /// `Ok(())` if successful, or the error that occurred.
508    pub fn reload_cert(&self, domain: &str) -> Result<(), TlsError> {
509        let normalized = domain.to_lowercase();
510        let storage_key = normalized.trim_start_matches("*.");
511
512        // Find the config
513        let config = {
514            let configs = self.cert_configs.read();
515            configs.get(storage_key).cloned()
516        };
517
518        let config = config.ok_or_else(|| TlsError::NoCertificate {
519            domain: domain.to_string(),
520        })?;
521
522        // Reload the certificate
523        let certified_key = self.load_cert_internal(&config)?;
524
525        // Apply the new certificate
526        if config.is_wildcard {
527            let mut wildcards = self.wildcard_certs.write();
528            wildcards.insert(storage_key.to_string(), certified_key);
529        } else {
530            let mut exact = self.exact_certs.write();
531            exact.insert(storage_key.to_string(), certified_key);
532        }
533
534        info!("Reloaded certificate for {}", domain);
535        Ok(())
536    }
537
538    /// Internal helper to load a certificate from config without storing it.
539    fn load_cert_internal(&self, config: &TlsCertConfig) -> Result<Arc<CertifiedKey>, TlsError> {
540        // Validate paths for traversal
541        Self::validate_path(&config.cert_path)?;
542        Self::validate_path(&config.key_path)?;
543
544        // Load certificate
545        let cert_pem = Self::read_file_secure(&config.cert_path, MAX_CERT_SIZE, "certificate")?;
546
547        // Load private key
548        let key_pem = Self::read_file_secure(&config.key_path, MAX_CERT_SIZE, "key")?;
549
550        // Create certified key
551        Ok(Arc::new(CertifiedKey {
552            cert_pem: Arc::new(cert_pem),
553            key_pem: Arc::new(SecureString::new(key_pem)),
554            domain: config.domain.clone(),
555        }))
556    }
557
558    /// Returns the list of configured domains (for monitoring/diagnostics).
559    pub fn configured_domains(&self) -> Vec<String> {
560        let configs = self.cert_configs.read();
561        configs.keys().cloned().collect()
562    }
563
564    /// Returns true if a certificate is configured for the given domain.
565    pub fn has_cert_config(&self, domain: &str) -> bool {
566        let normalized = domain.to_lowercase();
567        let storage_key = normalized.trim_start_matches("*.");
568        let configs = self.cert_configs.read();
569        configs.contains_key(storage_key)
570    }
571
572    /// Returns the minimum TLS version.
573    pub fn min_version(&self) -> TlsVersion {
574        self.min_version
575    }
576
577    /// Returns the number of loaded certificates.
578    pub fn cert_count(&self) -> usize {
579        self.exact_certs.read().len() + self.wildcard_certs.read().len()
580    }
581}
582
583impl Default for TlsManager {
584    fn default() -> Self {
585        Self::with_tls12_minimum()
586    }
587}
588
589#[cfg(test)]
590mod tests {
591    use super::*;
592    use std::io::Write;
593    use tempfile::NamedTempFile;
594
595    fn create_temp_file(content: &str) -> NamedTempFile {
596        let mut file = NamedTempFile::new().unwrap();
597        file.write_all(content.as_bytes()).unwrap();
598        file
599    }
600
601    const DUMMY_CERT: &str = "-----BEGIN CERTIFICATE-----\nMIIB...\n-----END CERTIFICATE-----";
602    const DUMMY_KEY: &str = "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----";
603
604    #[test]
605    fn test_load_exact_cert() {
606        let cert_file = create_temp_file(DUMMY_CERT);
607        let key_file = create_temp_file(DUMMY_KEY);
608
609        let manager = TlsManager::default();
610        let config = TlsCertConfig {
611            domain: "example.com".to_string(),
612            cert_path: cert_file.path().to_string_lossy().to_string(),
613            key_path: key_file.path().to_string_lossy().to_string(),
614            is_wildcard: false,
615        };
616
617        manager.load_cert(&config).unwrap();
618        assert!(manager.get_cert("example.com").is_some());
619        assert!(manager.get_cert("other.com").is_none());
620    }
621
622    #[test]
623    fn test_load_wildcard_cert() {
624        let cert_file = create_temp_file(DUMMY_CERT);
625        let key_file = create_temp_file(DUMMY_KEY);
626
627        let manager = TlsManager::default();
628        let config = TlsCertConfig {
629            domain: "*.example.com".to_string(),
630            cert_path: cert_file.path().to_string_lossy().to_string(),
631            key_path: key_file.path().to_string_lossy().to_string(),
632            is_wildcard: true,
633        };
634
635        manager.load_cert(&config).unwrap();
636
637        // Wildcard should match subdomains
638        assert!(manager.get_cert("api.example.com").is_some());
639        assert!(manager.get_cert("www.example.com").is_some());
640
641        // Should not match the bare domain or other domains
642        assert!(manager.get_cert("example.com").is_none());
643        assert!(manager.get_cert("other.com").is_none());
644    }
645
646    #[test]
647    fn test_default_cert() {
648        let cert_file = create_temp_file(DUMMY_CERT);
649        let key_file = create_temp_file(DUMMY_KEY);
650
651        let manager = TlsManager::default();
652        let config = TlsCertConfig {
653            domain: "default.local".to_string(),
654            cert_path: cert_file.path().to_string_lossy().to_string(),
655            key_path: key_file.path().to_string_lossy().to_string(),
656            is_wildcard: false,
657        };
658
659        manager.set_default_cert(&config).unwrap();
660
661        // Any unmatched domain should get the default
662        assert!(manager.get_cert("random.com").is_some());
663        assert!(manager.get_cert("anything.org").is_some());
664    }
665
666    #[test]
667    fn test_case_insensitive() {
668        let cert_file = create_temp_file(DUMMY_CERT);
669        let key_file = create_temp_file(DUMMY_KEY);
670
671        let manager = TlsManager::default();
672        let config = TlsCertConfig {
673            domain: "Example.COM".to_string(),
674            cert_path: cert_file.path().to_string_lossy().to_string(),
675            key_path: key_file.path().to_string_lossy().to_string(),
676            is_wildcard: false,
677        };
678
679        manager.load_cert(&config).unwrap();
680
681        assert!(manager.get_cert("example.com").is_some());
682        assert!(manager.get_cert("EXAMPLE.COM").is_some());
683    }
684
685    #[test]
686    fn test_path_traversal() {
687        let manager = TlsManager::default();
688        let config = TlsCertConfig {
689            domain: "example.com".to_string(),
690            cert_path: "../../../etc/passwd".to_string(),
691            key_path: "key.pem".to_string(),
692            is_wildcard: false,
693        };
694
695        let result = manager.load_cert(&config);
696        assert!(matches!(result, Err(TlsError::PathTraversal { .. })));
697    }
698
699    #[test]
700    fn test_cert_not_found() {
701        let key_file = create_temp_file(DUMMY_KEY);
702
703        let manager = TlsManager::default();
704        let config = TlsCertConfig {
705            domain: "example.com".to_string(),
706            cert_path: "/nonexistent/cert.pem".to_string(),
707            key_path: key_file.path().to_string_lossy().to_string(),
708            is_wildcard: false,
709        };
710
711        let result = manager.load_cert(&config);
712        assert!(matches!(result, Err(TlsError::CertNotFound { .. })));
713    }
714
715    #[test]
716    fn test_tls_version_parsing() {
717        assert_eq!(TlsVersion::from_str("1.2").unwrap(), TlsVersion::Tls12);
718        assert_eq!(TlsVersion::from_str("1.3").unwrap(), TlsVersion::Tls13);
719        assert_eq!(TlsVersion::from_str("TLSv1.2").unwrap(), TlsVersion::Tls12);
720        assert!(TlsVersion::from_str("1.1").is_err());
721    }
722
723    #[test]
724    fn test_debug_redacts_key() {
725        let cert = CertifiedKey {
726            cert_pem: Arc::new("cert content".to_string()),
727            key_pem: Arc::new(SecureString::new("secret key".to_string())),
728            domain: "example.com".to_string(),
729        };
730
731        let debug_output = format!("{:?}", cert);
732        assert!(debug_output.contains("REDACTED"));
733        assert!(!debug_output.contains("secret key"));
734    }
735
736    #[test]
737    fn test_cert_count() {
738        let cert_file = create_temp_file(DUMMY_CERT);
739        let key_file = create_temp_file(DUMMY_KEY);
740
741        let manager = TlsManager::default();
742        assert_eq!(manager.cert_count(), 0);
743
744        let config = TlsCertConfig {
745            domain: "example.com".to_string(),
746            cert_path: cert_file.path().to_string_lossy().to_string(),
747            key_path: key_file.path().to_string_lossy().to_string(),
748            is_wildcard: false,
749        };
750
751        manager.load_cert(&config).unwrap();
752        assert_eq!(manager.cert_count(), 1);
753    }
754
755    // ==================== Hot Reload Tests ====================
756
757    #[test]
758    fn test_reload_all_empty() {
759        let manager = TlsManager::default();
760        let result = manager.reload_all();
761
762        assert_eq!(result.succeeded, 0);
763        assert_eq!(result.failed, 0);
764        assert!(result.is_success());
765        assert!(result.errors.is_empty());
766    }
767
768    #[test]
769    fn test_reload_all_success() {
770        let cert_file = create_temp_file(DUMMY_CERT);
771        let key_file = create_temp_file(DUMMY_KEY);
772
773        let manager = TlsManager::default();
774        let config = TlsCertConfig {
775            domain: "example.com".to_string(),
776            cert_path: cert_file.path().to_string_lossy().to_string(),
777            key_path: key_file.path().to_string_lossy().to_string(),
778            is_wildcard: false,
779        };
780
781        manager.load_cert(&config).unwrap();
782
783        // Reload should succeed
784        let result = manager.reload_all();
785        assert_eq!(result.succeeded, 1);
786        assert_eq!(result.failed, 0);
787        assert!(result.is_success());
788
789        // Certificate should still be available
790        assert!(manager.get_cert("example.com").is_some());
791    }
792
793    #[test]
794    fn test_reload_all_multiple_certs() {
795        let cert_file1 = create_temp_file(DUMMY_CERT);
796        let key_file1 = create_temp_file(DUMMY_KEY);
797        let cert_file2 = create_temp_file(DUMMY_CERT);
798        let key_file2 = create_temp_file(DUMMY_KEY);
799
800        let manager = TlsManager::default();
801
802        // Load exact cert
803        manager
804            .load_cert(&TlsCertConfig {
805                domain: "example.com".to_string(),
806                cert_path: cert_file1.path().to_string_lossy().to_string(),
807                key_path: key_file1.path().to_string_lossy().to_string(),
808                is_wildcard: false,
809            })
810            .unwrap();
811
812        // Load wildcard cert
813        manager
814            .load_cert(&TlsCertConfig {
815                domain: "*.other.com".to_string(),
816                cert_path: cert_file2.path().to_string_lossy().to_string(),
817                key_path: key_file2.path().to_string_lossy().to_string(),
818                is_wildcard: true,
819            })
820            .unwrap();
821
822        let result = manager.reload_all();
823        assert_eq!(result.succeeded, 2);
824        assert_eq!(result.failed, 0);
825        assert!(result.is_success());
826
827        // Both certificates should still work
828        assert!(manager.get_cert("example.com").is_some());
829        assert!(manager.get_cert("api.other.com").is_some());
830    }
831
832    #[test]
833    fn test_reload_all_with_default() {
834        let cert_file = create_temp_file(DUMMY_CERT);
835        let key_file = create_temp_file(DUMMY_KEY);
836        let default_cert = create_temp_file(DUMMY_CERT);
837        let default_key = create_temp_file(DUMMY_KEY);
838
839        let manager = TlsManager::default();
840
841        manager
842            .load_cert(&TlsCertConfig {
843                domain: "example.com".to_string(),
844                cert_path: cert_file.path().to_string_lossy().to_string(),
845                key_path: key_file.path().to_string_lossy().to_string(),
846                is_wildcard: false,
847            })
848            .unwrap();
849
850        manager
851            .set_default_cert(&TlsCertConfig {
852                domain: "default.local".to_string(),
853                cert_path: default_cert.path().to_string_lossy().to_string(),
854                key_path: default_key.path().to_string_lossy().to_string(),
855                is_wildcard: false,
856            })
857            .unwrap();
858
859        let result = manager.reload_all();
860        assert_eq!(result.succeeded, 2); // 1 exact + 1 default
861        assert_eq!(result.failed, 0);
862    }
863
864    #[test]
865    fn test_reload_all_partial_failure() {
866        let cert_file = create_temp_file(DUMMY_CERT);
867        let key_file = create_temp_file(DUMMY_KEY);
868
869        let manager = TlsManager::default();
870
871        // Load valid cert
872        manager
873            .load_cert(&TlsCertConfig {
874                domain: "valid.com".to_string(),
875                cert_path: cert_file.path().to_string_lossy().to_string(),
876                key_path: key_file.path().to_string_lossy().to_string(),
877                is_wildcard: false,
878            })
879            .unwrap();
880
881        // Manually insert a config with invalid paths (simulating file deletion)
882        {
883            let mut configs = manager.cert_configs.write();
884            configs.insert(
885                "invalid.com".to_string(),
886                TlsCertConfig {
887                    domain: "invalid.com".to_string(),
888                    cert_path: "/nonexistent/cert.pem".to_string(),
889                    key_path: "/nonexistent/key.pem".to_string(),
890                    is_wildcard: false,
891                },
892            );
893        }
894
895        let result = manager.reload_all();
896        assert_eq!(result.succeeded, 1);
897        assert_eq!(result.failed, 1);
898        assert!(!result.is_success());
899        assert_eq!(result.errors.len(), 1);
900        assert!(result.errors[0].0.contains("invalid.com"));
901
902        // Valid cert should still be reloaded
903        assert!(manager.get_cert("valid.com").is_some());
904    }
905
906    #[test]
907    fn test_reload_single_cert() {
908        let cert_file = create_temp_file(DUMMY_CERT);
909        let key_file = create_temp_file(DUMMY_KEY);
910
911        let manager = TlsManager::default();
912        let config = TlsCertConfig {
913            domain: "example.com".to_string(),
914            cert_path: cert_file.path().to_string_lossy().to_string(),
915            key_path: key_file.path().to_string_lossy().to_string(),
916            is_wildcard: false,
917        };
918
919        manager.load_cert(&config).unwrap();
920
921        // Reload single cert
922        let result = manager.reload_cert("example.com");
923        assert!(result.is_ok());
924        assert!(manager.get_cert("example.com").is_some());
925    }
926
927    #[test]
928    fn test_reload_single_cert_case_insensitive() {
929        let cert_file = create_temp_file(DUMMY_CERT);
930        let key_file = create_temp_file(DUMMY_KEY);
931
932        let manager = TlsManager::default();
933        let config = TlsCertConfig {
934            domain: "Example.COM".to_string(),
935            cert_path: cert_file.path().to_string_lossy().to_string(),
936            key_path: key_file.path().to_string_lossy().to_string(),
937            is_wildcard: false,
938        };
939
940        manager.load_cert(&config).unwrap();
941
942        // Reload with different case
943        assert!(manager.reload_cert("EXAMPLE.com").is_ok());
944    }
945
946    #[test]
947    fn test_reload_single_cert_not_found() {
948        let manager = TlsManager::default();
949
950        let result = manager.reload_cert("notfound.com");
951        assert!(matches!(result, Err(TlsError::NoCertificate { .. })));
952    }
953
954    #[test]
955    fn test_reload_wildcard_cert() {
956        let cert_file = create_temp_file(DUMMY_CERT);
957        let key_file = create_temp_file(DUMMY_KEY);
958
959        let manager = TlsManager::default();
960        let config = TlsCertConfig {
961            domain: "*.example.com".to_string(),
962            cert_path: cert_file.path().to_string_lossy().to_string(),
963            key_path: key_file.path().to_string_lossy().to_string(),
964            is_wildcard: true,
965        };
966
967        manager.load_cert(&config).unwrap();
968
969        // Reload wildcard cert
970        let result = manager.reload_cert("*.example.com");
971        assert!(result.is_ok());
972        assert!(manager.get_cert("api.example.com").is_some());
973    }
974
975    #[test]
976    fn test_configured_domains() {
977        let cert_file1 = create_temp_file(DUMMY_CERT);
978        let key_file1 = create_temp_file(DUMMY_KEY);
979        let cert_file2 = create_temp_file(DUMMY_CERT);
980        let key_file2 = create_temp_file(DUMMY_KEY);
981
982        let manager = TlsManager::default();
983        assert!(manager.configured_domains().is_empty());
984
985        manager
986            .load_cert(&TlsCertConfig {
987                domain: "one.com".to_string(),
988                cert_path: cert_file1.path().to_string_lossy().to_string(),
989                key_path: key_file1.path().to_string_lossy().to_string(),
990                is_wildcard: false,
991            })
992            .unwrap();
993
994        manager
995            .load_cert(&TlsCertConfig {
996                domain: "*.two.com".to_string(),
997                cert_path: cert_file2.path().to_string_lossy().to_string(),
998                key_path: key_file2.path().to_string_lossy().to_string(),
999                is_wildcard: true,
1000            })
1001            .unwrap();
1002
1003        let domains = manager.configured_domains();
1004        assert_eq!(domains.len(), 2);
1005        assert!(domains.contains(&"one.com".to_string()));
1006        assert!(domains.contains(&"two.com".to_string())); // Wildcard stored by base domain
1007    }
1008
1009    #[test]
1010    fn test_has_cert_config() {
1011        let cert_file = create_temp_file(DUMMY_CERT);
1012        let key_file = create_temp_file(DUMMY_KEY);
1013
1014        let manager = TlsManager::default();
1015        assert!(!manager.has_cert_config("example.com"));
1016
1017        manager
1018            .load_cert(&TlsCertConfig {
1019                domain: "example.com".to_string(),
1020                cert_path: cert_file.path().to_string_lossy().to_string(),
1021                key_path: key_file.path().to_string_lossy().to_string(),
1022                is_wildcard: false,
1023            })
1024            .unwrap();
1025
1026        assert!(manager.has_cert_config("example.com"));
1027        assert!(manager.has_cert_config("EXAMPLE.COM")); // Case insensitive
1028        assert!(!manager.has_cert_config("other.com"));
1029    }
1030
1031    #[test]
1032    fn test_reload_updates_cert_content() {
1033        use std::io::{Seek, SeekFrom};
1034
1035        let mut cert_file = NamedTempFile::new().unwrap();
1036        let mut key_file = NamedTempFile::new().unwrap();
1037
1038        // Write initial content
1039        cert_file.write_all(DUMMY_CERT.as_bytes()).unwrap();
1040        key_file.write_all(DUMMY_KEY.as_bytes()).unwrap();
1041
1042        let manager = TlsManager::default();
1043        let config = TlsCertConfig {
1044            domain: "example.com".to_string(),
1045            cert_path: cert_file.path().to_string_lossy().to_string(),
1046            key_path: key_file.path().to_string_lossy().to_string(),
1047            is_wildcard: false,
1048        };
1049
1050        manager.load_cert(&config).unwrap();
1051
1052        // Get initial cert
1053        let cert1 = manager.get_cert("example.com").unwrap();
1054        let initial_cert = cert1.cert_pem.clone();
1055
1056        // Update cert file with new content - use as_file_mut() to get mutable file handle
1057        let new_cert = "-----BEGIN CERTIFICATE-----\nNEW_CERT\n-----END CERTIFICATE-----";
1058        {
1059            let file = cert_file.as_file_mut();
1060            file.seek(SeekFrom::Start(0)).unwrap();
1061            file.set_len(0).unwrap();
1062        }
1063        cert_file.write_all(new_cert.as_bytes()).unwrap();
1064
1065        // Reload
1066        manager.reload_cert("example.com").unwrap();
1067
1068        // Verify cert was updated
1069        let cert2 = manager.get_cert("example.com").unwrap();
1070        assert_ne!(*initial_cert, *cert2.cert_pem);
1071        assert!(cert2.cert_pem.contains("NEW_CERT"));
1072    }
1073
1074    #[test]
1075    fn test_reload_result_debug() {
1076        let result = ReloadResult {
1077            succeeded: 5,
1078            failed: 2,
1079            errors: vec![
1080                ("domain1.com".to_string(), "file not found".to_string()),
1081                ("domain2.com".to_string(), "permission denied".to_string()),
1082            ],
1083        };
1084
1085        let debug_output = format!("{:?}", result);
1086        assert!(debug_output.contains("succeeded: 5"));
1087        assert!(debug_output.contains("failed: 2"));
1088        assert!(debug_output.contains("domain1.com"));
1089    }
1090}