sentinel_proxy/
tls.rs

1//! TLS Configuration and SNI Support
2//!
3//! This module provides TLS configuration with Server Name Indication (SNI) support
4//! for serving multiple certificates based on the requested hostname.
5//!
6//! # Features
7//!
8//! - SNI-based certificate selection
9//! - Wildcard certificate matching (e.g., `*.example.com`)
10//! - Default certificate fallback
11//! - Certificate validation at startup
12//! - mTLS client certificate verification
13//! - Certificate hot-reload on SIGHUP
14//! - OCSP stapling support
15//!
16//! # Example KDL Configuration
17//!
18//! ```kdl
19//! listener "https" {
20//!     address "0.0.0.0:443"
21//!     protocol "https"
22//!     tls {
23//!         cert-file "/etc/certs/default.crt"
24//!         key-file "/etc/certs/default.key"
25//!
26//!         // SNI certificates
27//!         sni {
28//!             hostnames "example.com" "www.example.com"
29//!             cert-file "/etc/certs/example.crt"
30//!             key-file "/etc/certs/example.key"
31//!         }
32//!         sni {
33//!             hostnames "*.api.example.com"
34//!             cert-file "/etc/certs/api-wildcard.crt"
35//!             key-file "/etc/certs/api-wildcard.key"
36//!         }
37//!
38//!         // mTLS configuration
39//!         ca-file "/etc/certs/ca.crt"
40//!         client-auth true
41//!
42//!         // OCSP stapling
43//!         ocsp-stapling true
44//!     }
45//! }
46//! ```
47
48use std::collections::HashMap;
49use std::fs::File;
50use std::io::BufReader;
51use std::path::Path;
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54
55use parking_lot::RwLock;
56use rustls::client::ClientConfig;
57use rustls::pki_types::CertificateDer;
58use rustls::server::{ClientHello, ResolvesServerCert};
59use rustls::sign::CertifiedKey;
60use rustls::{RootCertStore, ServerConfig};
61use tracing::{debug, error, info, trace, warn};
62
63use sentinel_config::{TlsConfig, UpstreamTlsConfig};
64
/// Errors raised by the TLS subsystem.
///
/// Each variant wraps a free-form detail string that is appended to a fixed,
/// variant-specific prefix when the error is rendered with `Display`.
#[derive(Debug)]
pub enum TlsError {
    /// A certificate file failed to load
    CertificateLoad(String),
    /// A private key file failed to load
    KeyLoad(String),
    /// The TLS configuration could not be built
    ConfigBuild(String),
    /// The certificate and private key do not match
    CertKeyMismatch(String),
    /// The certificate is invalid
    InvalidCertificate(String),
    /// An OCSP response could not be fetched
    OcspFetch(String),
}

impl std::fmt::Display for TlsError {
    /// Renders the error as `"<prefix>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the fixed prefix and borrow the detail, then format once.
        let (prefix, detail) = match self {
            Self::CertificateLoad(detail) => ("Failed to load certificate", detail),
            Self::KeyLoad(detail) => ("Failed to load private key", detail),
            Self::ConfigBuild(detail) => ("Failed to build TLS config", detail),
            Self::CertKeyMismatch(detail) => ("Certificate/key mismatch", detail),
            Self::InvalidCertificate(detail) => ("Invalid certificate", detail),
            Self::OcspFetch(detail) => ("Failed to fetch OCSP response", detail),
        };
        write!(f, "{}: {}", prefix, detail)
    }
}

impl std::error::Error for TlsError {}
96
97/// SNI-aware certificate resolver
98///
99/// Resolves certificates based on the Server Name Indication (SNI) extension
100/// in the TLS handshake. Supports:
101/// - Exact hostname matches
102/// - Wildcard certificates (e.g., `*.example.com`)
103/// - Default certificate fallback
104#[derive(Debug)]
105pub struct SniResolver {
106    /// Default certificate (used when no SNI match)
107    default_cert: Arc<CertifiedKey>,
108    /// SNI hostname to certificate mapping
109    /// Key is lowercase hostname, value is the certified key
110    sni_certs: HashMap<String, Arc<CertifiedKey>>,
111    /// Wildcard certificates (e.g., "*.example.com" -> cert)
112    wildcard_certs: HashMap<String, Arc<CertifiedKey>>,
113}
114
115impl SniResolver {
116    /// Create a new SNI resolver from TLS configuration
117    pub fn from_config(config: &TlsConfig) -> Result<Self, TlsError> {
118        // Get cert_file and key_file - required for non-ACME configs
119        let (cert_file, key_file) = match (&config.cert_file, &config.key_file) {
120            (Some(cert), Some(key)) => (cert, key),
121            _ => {
122                return Err(TlsError::ConfigBuild(
123                    "TLS configuration requires cert_file and key_file".to_string(),
124                ));
125            }
126        };
127
128        // Load default certificate
129        let default_cert = load_certified_key(cert_file, key_file)?;
130
131        info!(
132            cert_file = %cert_file.display(),
133            "Loaded default TLS certificate"
134        );
135
136        let mut sni_certs = HashMap::new();
137        let mut wildcard_certs = HashMap::new();
138
139        // Load SNI certificates
140        for sni_config in &config.additional_certs {
141            let cert = load_certified_key(&sni_config.cert_file, &sni_config.key_file)?;
142            let cert = Arc::new(cert);
143
144            for hostname in &sni_config.hostnames {
145                let hostname_lower = hostname.to_lowercase();
146
147                if hostname_lower.starts_with("*.") {
148                    // Wildcard certificate
149                    let domain = hostname_lower.strip_prefix("*.").unwrap().to_string();
150                    wildcard_certs.insert(domain.clone(), cert.clone());
151                    debug!(
152                        pattern = %hostname,
153                        domain = %domain,
154                        cert_file = %sni_config.cert_file.display(),
155                        "Registered wildcard SNI certificate"
156                    );
157                } else {
158                    // Exact hostname match
159                    sni_certs.insert(hostname_lower.clone(), cert.clone());
160                    debug!(
161                        hostname = %hostname_lower,
162                        cert_file = %sni_config.cert_file.display(),
163                        "Registered SNI certificate"
164                    );
165                }
166            }
167        }
168
169        info!(
170            exact_certs = sni_certs.len(),
171            wildcard_certs = wildcard_certs.len(),
172            "SNI resolver initialized"
173        );
174
175        Ok(Self {
176            default_cert: Arc::new(default_cert),
177            sni_certs,
178            wildcard_certs,
179        })
180    }
181
182    /// Resolve certificate for a given server name
183    ///
184    /// This is the core resolution logic. For the rustls trait implementation,
185    /// see `ResolvesServerCert`.
186    pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
187        let Some(name) = server_name else {
188            debug!("No SNI provided, using default certificate");
189            return self.default_cert.clone();
190        };
191
192        let name_lower = name.to_lowercase();
193
194        // Try exact match first
195        if let Some(cert) = self.sni_certs.get(&name_lower) {
196            debug!(hostname = %name_lower, "SNI exact match found");
197            return cert.clone();
198        }
199
200        // Try wildcard match
201        // For "foo.bar.example.com", try "bar.example.com", then "example.com"
202        let parts: Vec<&str> = name_lower.split('.').collect();
203        for i in 1..parts.len() {
204            let domain = parts[i..].join(".");
205            if let Some(cert) = self.wildcard_certs.get(&domain) {
206                debug!(
207                    hostname = %name_lower,
208                    wildcard_domain = %domain,
209                    "SNI wildcard match found"
210                );
211                return cert.clone();
212            }
213        }
214
215        debug!(
216            hostname = %name_lower,
217            "No SNI match found, using default certificate"
218        );
219        self.default_cert.clone()
220    }
221}
222
223impl ResolvesServerCert for SniResolver {
224    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
225        Some(self.resolve(client_hello.server_name()))
226    }
227}
228
229// ============================================================================
230// Hot-Reloadable Certificate Support
231// ============================================================================
232
233/// Hot-reloadable SNI certificate resolver
234///
235/// Wraps an SniResolver behind an RwLock to allow certificate hot-reload
236/// without restarting the server. On SIGHUP, the inner resolver is replaced
237/// with a newly loaded one.
238pub struct HotReloadableSniResolver {
239    /// Inner resolver (protected by RwLock for hot-reload)
240    inner: RwLock<Arc<SniResolver>>,
241    /// Original config for reloading
242    config: RwLock<TlsConfig>,
243    /// Last reload time
244    last_reload: RwLock<Instant>,
245}
246
247impl std::fmt::Debug for HotReloadableSniResolver {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        f.debug_struct("HotReloadableSniResolver")
250            .field("last_reload", &*self.last_reload.read())
251            .finish()
252    }
253}
254
255impl HotReloadableSniResolver {
256    /// Create a new hot-reloadable resolver from TLS configuration
257    pub fn from_config(config: TlsConfig) -> Result<Self, TlsError> {
258        let resolver = SniResolver::from_config(&config)?;
259
260        Ok(Self {
261            inner: RwLock::new(Arc::new(resolver)),
262            config: RwLock::new(config),
263            last_reload: RwLock::new(Instant::now()),
264        })
265    }
266
267    /// Reload certificates from disk
268    ///
269    /// This is called on SIGHUP to pick up new certificates without restart.
270    /// If the reload fails, the old certificates continue to be used.
271    pub fn reload(&self) -> Result<(), TlsError> {
272        let config = self.config.read();
273
274        let cert_file_display = config
275            .cert_file
276            .as_ref()
277            .map(|p| p.display().to_string())
278            .unwrap_or_else(|| "(acme-managed)".to_string());
279
280        info!(
281            cert_file = %cert_file_display,
282            sni_count = config.additional_certs.len(),
283            "Reloading TLS certificates"
284        );
285
286        // Try to load new certificates
287        let new_resolver = SniResolver::from_config(&config)?;
288
289        // Swap in the new resolver atomically
290        *self.inner.write() = Arc::new(new_resolver);
291        *self.last_reload.write() = Instant::now();
292
293        info!("TLS certificates reloaded successfully");
294        Ok(())
295    }
296
297    /// Update configuration and reload
298    pub fn update_config(&self, new_config: TlsConfig) -> Result<(), TlsError> {
299        // Load with new config first
300        let new_resolver = SniResolver::from_config(&new_config)?;
301
302        // Update both config and resolver
303        *self.config.write() = new_config;
304        *self.inner.write() = Arc::new(new_resolver);
305        *self.last_reload.write() = Instant::now();
306
307        info!("TLS configuration updated and certificates reloaded");
308        Ok(())
309    }
310
311    /// Get time since last reload
312    pub fn last_reload_age(&self) -> Duration {
313        self.last_reload.read().elapsed()
314    }
315
316    /// Resolve certificate for a given server name
317    ///
318    /// This is the core resolution logic exposed for testing.
319    pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
320        self.inner.read().resolve(server_name)
321    }
322}
323
324impl ResolvesServerCert for HotReloadableSniResolver {
325    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
326        Some(self.inner.read().resolve(client_hello.server_name()))
327    }
328}
329
330/// Certificate reload manager
331///
332/// Tracks all TLS listeners and provides a unified reload interface.
333pub struct CertificateReloader {
334    /// Map of listener ID to hot-reloadable resolver
335    resolvers: RwLock<HashMap<String, Arc<HotReloadableSniResolver>>>,
336}
337
338impl CertificateReloader {
339    /// Create a new certificate reloader
340    pub fn new() -> Self {
341        Self {
342            resolvers: RwLock::new(HashMap::new()),
343        }
344    }
345
346    /// Register a resolver for a listener
347    pub fn register(&self, listener_id: &str, resolver: Arc<HotReloadableSniResolver>) {
348        debug!(listener_id = %listener_id, "Registering TLS resolver for hot-reload");
349        self.resolvers
350            .write()
351            .insert(listener_id.to_string(), resolver);
352    }
353
354    /// Reload all registered certificates
355    ///
356    /// Returns the number of successfully reloaded listeners and any errors.
357    pub fn reload_all(&self) -> (usize, Vec<(String, TlsError)>) {
358        let resolvers = self.resolvers.read();
359        let mut success_count = 0;
360        let mut errors = Vec::new();
361
362        info!(
363            listener_count = resolvers.len(),
364            "Reloading certificates for all TLS listeners"
365        );
366
367        for (listener_id, resolver) in resolvers.iter() {
368            match resolver.reload() {
369                Ok(()) => {
370                    success_count += 1;
371                    debug!(listener_id = %listener_id, "Certificate reload successful");
372                }
373                Err(e) => {
374                    error!(listener_id = %listener_id, error = %e, "Certificate reload failed");
375                    errors.push((listener_id.clone(), e));
376                }
377            }
378        }
379
380        if errors.is_empty() {
381            info!(
382                success_count = success_count,
383                "All certificates reloaded successfully"
384            );
385        } else {
386            warn!(
387                success_count = success_count,
388                error_count = errors.len(),
389                "Certificate reload completed with errors"
390            );
391        }
392
393        (success_count, errors)
394    }
395
396    /// Get reload status for all listeners
397    pub fn status(&self) -> HashMap<String, Duration> {
398        self.resolvers
399            .read()
400            .iter()
401            .map(|(id, resolver)| (id.clone(), resolver.last_reload_age()))
402            .collect()
403    }
404}
405
406impl Default for CertificateReloader {
407    fn default() -> Self {
408        Self::new()
409    }
410}
411
412// ============================================================================
413// OCSP Stapling Support
414// ============================================================================
415
/// A single cached OCSP response together with its timing metadata
#[derive(Debug, Clone)]
pub struct OcspCacheEntry {
    /// The OCSP response bytes (DER-encoded)
    pub response: Vec<u8>,
    /// Instant at which the response was fetched
    pub fetched_at: Instant,
    /// Instant at which the response expires (from its nextUpdate field),
    /// or `None` when no expiry has been recorded
    pub expires_at: Option<Instant>,
}
426
427/// OCSP stapling manager
428///
429/// Fetches and caches OCSP responses for certificates.
430pub struct OcspStapler {
431    /// Cache of OCSP responses by certificate fingerprint
432    cache: RwLock<HashMap<String, OcspCacheEntry>>,
433    /// Refresh interval for OCSP responses (default 1 hour)
434    refresh_interval: Duration,
435}
436
437impl OcspStapler {
438    /// Create a new OCSP stapler
439    pub fn new() -> Self {
440        Self {
441            cache: RwLock::new(HashMap::new()),
442            refresh_interval: Duration::from_secs(3600), // 1 hour default
443        }
444    }
445
446    /// Create with custom refresh interval
447    pub fn with_refresh_interval(interval: Duration) -> Self {
448        Self {
449            cache: RwLock::new(HashMap::new()),
450            refresh_interval: interval,
451        }
452    }
453
454    /// Get cached OCSP response for a certificate
455    pub fn get_response(&self, cert_fingerprint: &str) -> Option<Vec<u8>> {
456        let cache = self.cache.read();
457        if let Some(entry) = cache.get(cert_fingerprint) {
458            // Check if response is still valid
459            if entry.fetched_at.elapsed() < self.refresh_interval {
460                trace!(fingerprint = %cert_fingerprint, "OCSP cache hit");
461                return Some(entry.response.clone());
462            }
463            trace!(fingerprint = %cert_fingerprint, "OCSP cache expired");
464        }
465        None
466    }
467
468    /// Fetch OCSP response for a certificate
469    ///
470    /// This performs an HTTP request to the OCSP responder specified in the
471    /// certificate's Authority Information Access extension.
472    pub fn fetch_ocsp_response(
473        &self,
474        cert_der: &[u8],
475        issuer_der: &[u8],
476    ) -> Result<Vec<u8>, TlsError> {
477        use x509_parser::prelude::*;
478
479        // Parse the end-entity certificate
480        let (_, cert) = X509Certificate::from_der(cert_der)
481            .map_err(|e| TlsError::OcspFetch(format!("Failed to parse certificate: {}", e)))?;
482
483        // Parse the issuer certificate
484        let (_, issuer) = X509Certificate::from_der(issuer_der)
485            .map_err(|e| TlsError::OcspFetch(format!("Failed to parse issuer certificate: {}", e)))?;
486
487        // Extract OCSP responder URL from AIA extension
488        let ocsp_url = extract_ocsp_responder_url(&cert)?;
489        debug!(url = %ocsp_url, "Found OCSP responder URL");
490
491        // Build OCSP request
492        let ocsp_request = build_ocsp_request(&cert, &issuer)?;
493
494        // Send request synchronously (blocking context)
495        // Note: In production, this should be async with proper timeout handling
496        let response = send_ocsp_request_sync(&ocsp_url, &ocsp_request)?;
497
498        // Calculate fingerprint for caching
499        let fingerprint = calculate_cert_fingerprint(cert_der);
500
501        // Cache the response
502        let entry = OcspCacheEntry {
503            response: response.clone(),
504            fetched_at: Instant::now(),
505            expires_at: None, // Could parse nextUpdate from response
506        };
507        self.cache.write().insert(fingerprint, entry);
508
509        info!("Successfully fetched and cached OCSP response");
510        Ok(response)
511    }
512
513    /// Async version of fetch_ocsp_response
514    pub async fn fetch_ocsp_response_async(
515        &self,
516        cert_der: &[u8],
517        issuer_der: &[u8],
518    ) -> Result<Vec<u8>, TlsError> {
519        use x509_parser::prelude::*;
520
521        // Parse the end-entity certificate
522        let (_, cert) = X509Certificate::from_der(cert_der)
523            .map_err(|e| TlsError::OcspFetch(format!("Failed to parse certificate: {}", e)))?;
524
525        // Parse the issuer certificate
526        let (_, issuer) = X509Certificate::from_der(issuer_der)
527            .map_err(|e| TlsError::OcspFetch(format!("Failed to parse issuer certificate: {}", e)))?;
528
529        // Extract OCSP responder URL from AIA extension
530        let ocsp_url = extract_ocsp_responder_url(&cert)?;
531        debug!(url = %ocsp_url, "Found OCSP responder URL");
532
533        // Build OCSP request
534        let ocsp_request = build_ocsp_request(&cert, &issuer)?;
535
536        // Send request asynchronously
537        let response = send_ocsp_request_async(&ocsp_url, &ocsp_request).await?;
538
539        // Calculate fingerprint for caching
540        let fingerprint = calculate_cert_fingerprint(cert_der);
541
542        // Cache the response
543        let entry = OcspCacheEntry {
544            response: response.clone(),
545            fetched_at: Instant::now(),
546            expires_at: None,
547        };
548        self.cache.write().insert(fingerprint, entry);
549
550        info!("Successfully fetched and cached OCSP response (async)");
551        Ok(response)
552    }
553
554    /// Prefetch OCSP responses for all certificates in a config
555    pub fn prefetch_for_config(&self, config: &TlsConfig) -> Vec<String> {
556        let mut warnings = Vec::new();
557
558        if !config.ocsp_stapling {
559            trace!("OCSP stapling disabled in config");
560            return warnings;
561        }
562
563        info!("Prefetching OCSP responses for certificates");
564
565        // For now, just log that we would prefetch
566        // Full implementation would iterate certificates and fetch OCSP responses
567        warnings.push("OCSP stapling prefetch not yet fully implemented".to_string());
568
569        warnings
570    }
571
572    /// Clear the OCSP cache
573    pub fn clear_cache(&self) {
574        self.cache.write().clear();
575        info!("OCSP cache cleared");
576    }
577}
578
579impl Default for OcspStapler {
580    fn default() -> Self {
581        Self::new()
582    }
583}
584
585// ============================================================================
586// OCSP Helper Functions
587// ============================================================================
588
589/// Extract OCSP responder URL from certificate's Authority Information Access extension
590fn extract_ocsp_responder_url(cert: &x509_parser::certificate::X509Certificate) -> Result<String, TlsError> {
591    use x509_parser::prelude::*;
592
593    // Find the AIA extension
594    let aia = cert
595        .extensions()
596        .iter()
597        .find(|ext| ext.oid == oid_registry::OID_PKIX_AUTHORITY_INFO_ACCESS)
598        .ok_or_else(|| TlsError::OcspFetch(
599            "Certificate does not have Authority Information Access extension".to_string()
600        ))?;
601
602    // Parse AIA extension
603    let aia_value = match aia.parsed_extension() {
604        ParsedExtension::AuthorityInfoAccess(aia) => aia,
605        _ => return Err(TlsError::OcspFetch(
606            "Failed to parse Authority Information Access extension".to_string()
607        )),
608    };
609
610    // Find OCSP access method
611    for access in &aia_value.accessdescs {
612        if access.access_method == oid_registry::OID_PKIX_ACCESS_DESCRIPTOR_OCSP {
613            match &access.access_location {
614                GeneralName::URI(url) => {
615                    return Ok(url.to_string());
616                }
617                _ => continue,
618            }
619        }
620    }
621
622    Err(TlsError::OcspFetch(
623        "Certificate AIA does not contain OCSP responder URL".to_string()
624    ))
625}
626
627/// Build an OCSP request for the given certificate
628///
629/// This builds a minimal OCSP request with SHA-256 hashes
630fn build_ocsp_request(
631    cert: &x509_parser::certificate::X509Certificate,
632    issuer: &x509_parser::certificate::X509Certificate,
633) -> Result<Vec<u8>, TlsError> {
634    use sha2::{Sha256, Digest};
635
636    // Per RFC 6960, an OCSP request contains:
637    // - Hash of issuer name
638    // - Hash of issuer public key
639    // - Certificate serial number
640
641    // Hash issuer name (Distinguished Name)
642    let issuer_name_hash = {
643        let mut hasher = Sha256::new();
644        hasher.update(issuer.subject().as_raw());
645        hasher.finalize()
646    };
647
648    // Hash issuer public key (the BIT STRING content, not including tag/length)
649    let issuer_key_hash = {
650        let mut hasher = Sha256::new();
651        hasher.update(issuer.public_key().subject_public_key.data.as_ref());
652        hasher.finalize()
653    };
654
655    // Get certificate serial number
656    let serial = cert.serial.to_bytes_be();
657
658    // Build ASN.1 DER encoded OCSP request
659    // This is a minimal implementation of the OCSP request structure
660    let request = build_ocsp_request_der(
661        &issuer_name_hash,
662        &issuer_key_hash,
663        &serial,
664    );
665
666    Ok(request)
667}
668
669/// Build DER-encoded OCSP request
670fn build_ocsp_request_der(
671    issuer_name_hash: &[u8],
672    issuer_key_hash: &[u8],
673    serial_number: &[u8],
674) -> Vec<u8> {
675    // OID for SHA-256
676    let sha256_oid: &[u8] = &[0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01];
677
678    // Build CertID structure
679    let hash_algorithm = der_sequence(&[
680        &der_oid(sha256_oid),
681        &der_null(),
682    ]);
683
684    let cert_id = der_sequence(&[
685        &hash_algorithm,
686        &der_octet_string(issuer_name_hash),
687        &der_octet_string(issuer_key_hash),
688        &der_integer(serial_number),
689    ]);
690
691    // Build Request structure
692    let request = der_sequence(&[&cert_id]);
693
694    // Build requestList (SEQUENCE OF Request)
695    let request_list = der_sequence(&[&request]);
696
697    // Build TBSRequest
698    let tbs_request = der_sequence(&[&request_list]);
699
700    // Build OCSPRequest
701    der_sequence(&[&tbs_request])
702}
703
704// DER encoding helpers
705fn der_sequence(items: &[&[u8]]) -> Vec<u8> {
706    let mut content = Vec::new();
707    for item in items {
708        content.extend_from_slice(item);
709    }
710    let mut result = vec![0x30]; // SEQUENCE tag
711    result.extend(der_length(content.len()));
712    result.extend(content);
713    result
714}
715
716fn der_oid(oid: &[u8]) -> Vec<u8> {
717    let mut result = vec![0x06]; // OID tag
718    result.extend(der_length(oid.len()));
719    result.extend_from_slice(oid);
720    result
721}
722
/// Encode a NULL value: tag 0x05 with zero length.
fn der_null() -> Vec<u8> {
    const NULL_TLV: [u8; 2] = [0x05, 0x00];
    NULL_TLV.to_vec()
}
726
727fn der_octet_string(data: &[u8]) -> Vec<u8> {
728    let mut result = vec![0x04]; // OCTET STRING tag
729    result.extend(der_length(data.len()));
730    result.extend_from_slice(data);
731    result
732}
733
734fn der_integer(data: &[u8]) -> Vec<u8> {
735    let mut result = vec![0x02]; // INTEGER tag
736    // Remove leading zeros but ensure at least one byte
737    let data = match data.iter().position(|&b| b != 0) {
738        Some(pos) => &data[pos..],
739        None => &[0],
740    };
741    // Add leading zero if high bit is set (to ensure positive)
742    if !data.is_empty() && data[0] & 0x80 != 0 {
743        result.extend(der_length(data.len() + 1));
744        result.push(0x00);
745    } else {
746        result.extend(der_length(data.len()));
747    }
748    result.extend_from_slice(data);
749    result
750}
751
/// Encode a DER definite-form length.
///
/// Lengths below 128 use the single-byte short form. Larger lengths use the
/// long form: a prefix byte `0x80 | n` followed by the length in `n`
/// big-endian bytes, where `n` is the minimum number of bytes needed. This
/// covers every `usize` value rather than truncating lengths of 65536 and up.
fn der_length(len: usize) -> Vec<u8> {
    if len < 0x80 {
        // Short form: guarded above, so the cast cannot truncate
        return vec![len as u8];
    }

    // Long form: drop leading zero bytes to get the minimal representation
    let be_bytes = len.to_be_bytes();
    let leading_zeros = be_bytes.iter().take_while(|&&b| b == 0).count();
    let significant = &be_bytes[leading_zeros..];

    let mut encoded = Vec::with_capacity(1 + significant.len());
    // `significant.len()` is at most size_of::<usize>(), so it fits in 7 bits
    encoded.push(0x80 | significant.len() as u8);
    encoded.extend_from_slice(significant);
    encoded
}
761
762/// Send OCSP request synchronously (blocking)
763fn send_ocsp_request_sync(url: &str, request: &[u8]) -> Result<Vec<u8>, TlsError> {
764    use std::io::{Read, Write};
765    use std::net::TcpStream;
766    use std::time::Duration;
767
768    // Parse URL to get host, port, and path
769    let url = url::Url::parse(url)
770        .map_err(|e| TlsError::OcspFetch(format!("Invalid OCSP URL: {}", e)))?;
771
772    let host = url.host_str()
773        .ok_or_else(|| TlsError::OcspFetch("OCSP URL has no host".to_string()))?;
774    let port = url.port().unwrap_or(80);
775    let path = if url.path().is_empty() { "/" } else { url.path() };
776
777    // Connect to server
778    let addr = format!("{}:{}", host, port);
779    let mut stream = TcpStream::connect(&addr)
780        .map_err(|e| TlsError::OcspFetch(format!("Failed to connect to OCSP responder: {}", e)))?;
781
782    stream.set_read_timeout(Some(Duration::from_secs(10)))
783        .map_err(|e| TlsError::OcspFetch(format!("Failed to set timeout: {}", e)))?;
784    stream.set_write_timeout(Some(Duration::from_secs(10)))
785        .map_err(|e| TlsError::OcspFetch(format!("Failed to set timeout: {}", e)))?;
786
787    // Build HTTP POST request
788    let http_request = format!(
789        "POST {} HTTP/1.1\r\n\
790         Host: {}\r\n\
791         Content-Type: application/ocsp-request\r\n\
792         Content-Length: {}\r\n\
793         Connection: close\r\n\
794         \r\n",
795        path, host, request.len()
796    );
797
798    // Send request
799    stream.write_all(http_request.as_bytes())
800        .map_err(|e| TlsError::OcspFetch(format!("Failed to send OCSP request: {}", e)))?;
801    stream.write_all(request)
802        .map_err(|e| TlsError::OcspFetch(format!("Failed to send OCSP request body: {}", e)))?;
803
804    // Read response
805    let mut response = Vec::new();
806    stream.read_to_end(&mut response)
807        .map_err(|e| TlsError::OcspFetch(format!("Failed to read OCSP response: {}", e)))?;
808
809    // Parse HTTP response - find body after headers
810    let headers_end = response.windows(4)
811        .position(|w| w == b"\r\n\r\n")
812        .ok_or_else(|| TlsError::OcspFetch("Invalid HTTP response: no headers end".to_string()))?;
813
814    let body = &response[headers_end + 4..];
815    if body.is_empty() {
816        return Err(TlsError::OcspFetch("Empty OCSP response body".to_string()));
817    }
818
819    Ok(body.to_vec())
820}
821
822/// Send OCSP request asynchronously
823async fn send_ocsp_request_async(url: &str, request: &[u8]) -> Result<Vec<u8>, TlsError> {
824    let client = reqwest::Client::builder()
825        .timeout(Duration::from_secs(10))
826        .build()
827        .map_err(|e| TlsError::OcspFetch(format!("Failed to create HTTP client: {}", e)))?;
828
829    let response = client
830        .post(url)
831        .header("Content-Type", "application/ocsp-request")
832        .body(request.to_vec())
833        .send()
834        .await
835        .map_err(|e| TlsError::OcspFetch(format!("OCSP request failed: {}", e)))?;
836
837    if !response.status().is_success() {
838        return Err(TlsError::OcspFetch(format!(
839            "OCSP responder returned status: {}",
840            response.status()
841        )));
842    }
843
844    let body = response.bytes().await
845        .map_err(|e| TlsError::OcspFetch(format!("Failed to read OCSP response: {}", e)))?;
846
847    Ok(body.to_vec())
848}
849
850/// Calculate certificate fingerprint for cache key
851fn calculate_cert_fingerprint(cert_der: &[u8]) -> String {
852    use sha2::{Sha256, Digest};
853    let mut hasher = Sha256::new();
854    hasher.update(cert_der);
855    let result = hasher.finalize();
856    hex::encode(result)
857}
858
859// ============================================================================
860// Upstream mTLS Support (Client Certificates)
861// ============================================================================
862
863/// Load client certificate and key for mTLS to upstreams
864///
865/// This function loads PEM-encoded certificates and private key and converts
866/// them to Pingora's CertKey format for use with `HttpPeer.client_cert_key`.
867///
868/// # Arguments
869///
870/// * `cert_path` - Path to PEM-encoded certificate (may include chain)
871/// * `key_path` - Path to PEM-encoded private key
872///
873/// # Returns
874///
875/// An `Arc<CertKey>` that can be set on `peer.client_cert_key` for mTLS
876pub fn load_client_cert_key(
877    cert_path: &Path,
878    key_path: &Path,
879) -> Result<Arc<pingora_core::utils::tls::CertKey>, TlsError> {
880    // Read certificate chain (PEM format, may contain intermediates)
881    let cert_file = File::open(cert_path)
882        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
883    let mut cert_reader = BufReader::new(cert_file);
884
885    // Parse certificates from PEM to DER
886    let cert_ders: Vec<Vec<u8>> = rustls_pemfile::certs(&mut cert_reader)
887        .collect::<Result<Vec<_>, _>>()
888        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?
889        .into_iter()
890        .map(|c| c.to_vec())
891        .collect();
892
893    if cert_ders.is_empty() {
894        return Err(TlsError::CertificateLoad(format!(
895            "{}: No certificates found in PEM file",
896            cert_path.display()
897        )));
898    }
899
900    // Read private key (PEM format)
901    let key_file = File::open(key_path)
902        .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
903    let mut key_reader = BufReader::new(key_file);
904
905    // Parse private key from PEM to DER
906    let key_der = rustls_pemfile::private_key(&mut key_reader)
907        .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
908        .ok_or_else(|| {
909            TlsError::KeyLoad(format!(
910                "{}: No private key found in PEM file",
911                key_path.display()
912            ))
913        })?
914        .secret_der()
915        .to_vec();
916
917    // Create Pingora's CertKey (certificates: Vec<Vec<u8>>, key: Vec<u8>)
918    let cert_key = pingora_core::utils::tls::CertKey::new(cert_ders, key_der);
919
920    debug!(
921        cert_path = %cert_path.display(),
922        key_path = %key_path.display(),
923        "Loaded mTLS client certificate for upstream connections"
924    );
925
926    Ok(Arc::new(cert_key))
927}
928
929/// Build a TLS client configuration for upstream connections with mTLS
930///
931/// This creates a rustls ClientConfig that can be used when Sentinel
932/// connects to backends that require client certificate authentication.
933pub fn build_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<ClientConfig, TlsError> {
934    let mut root_store = RootCertStore::empty();
935
936    // Load CA certificates for server verification
937    if let Some(ca_path) = &config.ca_cert {
938        let ca_file = File::open(ca_path)
939            .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
940        let mut ca_reader = BufReader::new(ca_file);
941
942        let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
943            .collect::<Result<Vec<_>, _>>()
944            .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
945
946        for cert in certs {
947            root_store.add(cert).map_err(|e| {
948                TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
949            })?;
950        }
951
952        debug!(
953            ca_file = %ca_path.display(),
954            cert_count = root_store.len(),
955            "Loaded upstream CA certificates"
956        );
957    } else if !config.insecure_skip_verify {
958        // Use webpki roots for standard TLS
959        root_store = RootCertStore {
960            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
961        };
962        trace!("Using webpki-roots for upstream TLS verification");
963    }
964
965    // Build the client config
966    let builder = ClientConfig::builder().with_root_certificates(root_store);
967
968    let client_config = if let (Some(cert_path), Some(key_path)) =
969        (&config.client_cert, &config.client_key)
970    {
971        // Load client certificate for mTLS
972        let cert_file = File::open(cert_path)
973            .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
974        let mut cert_reader = BufReader::new(cert_file);
975
976        let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
977            .collect::<Result<Vec<_>, _>>()
978            .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
979
980        if certs.is_empty() {
981            return Err(TlsError::CertificateLoad(format!(
982                "{}: No certificates found",
983                cert_path.display()
984            )));
985        }
986
987        // Load client private key
988        let key_file = File::open(key_path)
989            .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
990        let mut key_reader = BufReader::new(key_file);
991
992        let key = rustls_pemfile::private_key(&mut key_reader)
993            .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
994            .ok_or_else(|| {
995                TlsError::KeyLoad(format!("{}: No private key found", key_path.display()))
996            })?;
997
998        info!(
999            cert_file = %cert_path.display(),
1000            "Configured mTLS client certificate for upstream connections"
1001        );
1002
1003        builder
1004            .with_client_auth_cert(certs, key)
1005            .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to set client auth: {}", e)))?
1006    } else {
1007        // No client certificate
1008        builder.with_no_client_auth()
1009    };
1010
1011    debug!("Upstream TLS configuration built successfully");
1012    Ok(client_config)
1013}
1014
1015/// Validate upstream TLS configuration
1016pub fn validate_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<(), TlsError> {
1017    // Validate CA certificate if specified
1018    if let Some(ca_path) = &config.ca_cert {
1019        if !ca_path.exists() {
1020            return Err(TlsError::CertificateLoad(format!(
1021                "Upstream CA certificate not found: {}",
1022                ca_path.display()
1023            )));
1024        }
1025    }
1026
1027    // Validate client certificate pair if mTLS is configured
1028    if let Some(cert_path) = &config.client_cert {
1029        if !cert_path.exists() {
1030            return Err(TlsError::CertificateLoad(format!(
1031                "Upstream client certificate not found: {}",
1032                cert_path.display()
1033            )));
1034        }
1035
1036        // If cert is specified, key must also be specified
1037        match &config.client_key {
1038            Some(key_path) if !key_path.exists() => {
1039                return Err(TlsError::KeyLoad(format!(
1040                    "Upstream client key not found: {}",
1041                    key_path.display()
1042                )));
1043            }
1044            None => {
1045                return Err(TlsError::ConfigBuild(
1046                    "client_cert specified without client_key".to_string(),
1047                ));
1048            }
1049            _ => {}
1050        }
1051    }
1052
1053    if config.client_key.is_some() && config.client_cert.is_none() {
1054        return Err(TlsError::ConfigBuild(
1055            "client_key specified without client_cert".to_string(),
1056        ));
1057    }
1058
1059    Ok(())
1060}
1061
1062// ============================================================================
1063// Certificate Loading Functions
1064// ============================================================================
1065
1066/// Load a certificate chain and private key from files
1067fn load_certified_key(cert_path: &Path, key_path: &Path) -> Result<CertifiedKey, TlsError> {
1068    // Load certificate chain
1069    let cert_file = File::open(cert_path)
1070        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1071    let mut cert_reader = BufReader::new(cert_file);
1072
1073    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
1074        .collect::<Result<Vec<_>, _>>()
1075        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1076
1077    if certs.is_empty() {
1078        return Err(TlsError::CertificateLoad(format!(
1079            "{}: No certificates found in file",
1080            cert_path.display()
1081        )));
1082    }
1083
1084    // Load private key
1085    let key_file = File::open(key_path)
1086        .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
1087    let mut key_reader = BufReader::new(key_file);
1088
1089    let key = rustls_pemfile::private_key(&mut key_reader)
1090        .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
1091        .ok_or_else(|| {
1092            TlsError::KeyLoad(format!(
1093                "{}: No private key found in file",
1094                key_path.display()
1095            ))
1096        })?;
1097
1098    // Create signing key using the default crypto provider
1099    let provider = rustls::crypto::CryptoProvider::get_default()
1100        .cloned()
1101        .unwrap_or_else(|| Arc::new(rustls::crypto::aws_lc_rs::default_provider()));
1102
1103    let signing_key = provider
1104        .key_provider
1105        .load_private_key(key)
1106        .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to load private key: {:?}", e)))?;
1107
1108    Ok(CertifiedKey::new(certs, signing_key))
1109}
1110
1111/// Load CA certificates for client verification (mTLS)
1112pub fn load_client_ca(ca_path: &Path) -> Result<RootCertStore, TlsError> {
1113    let ca_file = File::open(ca_path)
1114        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
1115    let mut ca_reader = BufReader::new(ca_file);
1116
1117    let mut root_store = RootCertStore::empty();
1118
1119    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
1120        .collect::<Result<Vec<_>, _>>()
1121        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
1122
1123    for cert in certs {
1124        root_store.add(cert).map_err(|e| {
1125            TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
1126        })?;
1127    }
1128
1129    if root_store.is_empty() {
1130        return Err(TlsError::CertificateLoad(format!(
1131            "{}: No CA certificates found",
1132            ca_path.display()
1133        )));
1134    }
1135
1136    info!(
1137        ca_file = %ca_path.display(),
1138        cert_count = root_store.len(),
1139        "Loaded client CA certificates"
1140    );
1141
1142    Ok(root_store)
1143}
1144
1145/// Build a TLS ServerConfig from our configuration
1146pub fn build_server_config(config: &TlsConfig) -> Result<ServerConfig, TlsError> {
1147    let resolver = SniResolver::from_config(config)?;
1148
1149    let builder = ServerConfig::builder();
1150
1151    // Configure client authentication (mTLS)
1152    let server_config = if config.client_auth {
1153        if let Some(ca_path) = &config.ca_file {
1154            let root_store = load_client_ca(ca_path)?;
1155            let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
1156                .build()
1157                .map_err(|e| {
1158                    TlsError::ConfigBuild(format!("Failed to build client verifier: {}", e))
1159                })?;
1160
1161            info!("mTLS enabled: client certificates required");
1162
1163            builder
1164                .with_client_cert_verifier(verifier)
1165                .with_cert_resolver(Arc::new(resolver))
1166        } else {
1167            warn!("client_auth enabled but no ca_file specified, disabling client auth");
1168            builder
1169                .with_no_client_auth()
1170                .with_cert_resolver(Arc::new(resolver))
1171        }
1172    } else {
1173        builder
1174            .with_no_client_auth()
1175            .with_cert_resolver(Arc::new(resolver))
1176    };
1177
1178    // Configure ALPN for HTTP/2 support
1179    let mut config = server_config;
1180    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
1181
1182    debug!("TLS configuration built successfully");
1183
1184    Ok(config)
1185}
1186
1187/// Validate TLS configuration files exist and are readable
1188pub fn validate_tls_config(config: &TlsConfig) -> Result<(), TlsError> {
1189    // If ACME is configured, skip manual cert file validation
1190    if config.acme.is_some() {
1191        // ACME-managed certificates don't need cert_file/key_file to exist
1192        trace!("Skipping manual cert validation for ACME-managed TLS");
1193    } else {
1194        // Check default certificate (required for non-ACME configs)
1195        match (&config.cert_file, &config.key_file) {
1196            (Some(cert_file), Some(key_file)) => {
1197                if !cert_file.exists() {
1198                    return Err(TlsError::CertificateLoad(format!(
1199                        "Certificate file not found: {}",
1200                        cert_file.display()
1201                    )));
1202                }
1203                if !key_file.exists() {
1204                    return Err(TlsError::KeyLoad(format!(
1205                        "Key file not found: {}",
1206                        key_file.display()
1207                    )));
1208                }
1209            }
1210            _ => {
1211                return Err(TlsError::ConfigBuild(
1212                    "TLS configuration requires cert_file and key_file (or ACME block)".to_string(),
1213                ));
1214            }
1215        }
1216    }
1217
1218    // Check SNI certificates
1219    for sni in &config.additional_certs {
1220        if !sni.cert_file.exists() {
1221            return Err(TlsError::CertificateLoad(format!(
1222                "SNI certificate file not found: {}",
1223                sni.cert_file.display()
1224            )));
1225        }
1226        if !sni.key_file.exists() {
1227            return Err(TlsError::KeyLoad(format!(
1228                "SNI key file not found: {}",
1229                sni.key_file.display()
1230            )));
1231        }
1232    }
1233
1234    // Check CA file if mTLS enabled
1235    if config.client_auth {
1236        if let Some(ca_path) = &config.ca_file {
1237            if !ca_path.exists() {
1238                return Err(TlsError::CertificateLoad(format!(
1239                    "CA certificate file not found: {}",
1240                    ca_path.display()
1241                )));
1242            }
1243        }
1244    }
1245
1246    Ok(())
1247}
1248
#[cfg(test)]
mod tests {

    /// Checks the label arithmetic that wildcard matching relies on: dropping
    /// leading labels from a hostname yields its parent domains.
    ///
    /// No resolver or certificate is constructed here.
    #[test]
    fn test_wildcard_matching() {
        let labels: Vec<&str> = "foo.bar.example.com".split('.').collect();

        assert_eq!(labels.len(), 4);

        // Dropping one label gives the immediate parent domain.
        assert_eq!(labels[1..].join("."), "bar.example.com");

        // Dropping two labels gives the grandparent domain.
        assert_eq!(labels[2..].join("."), "example.com");
    }

    /// Hostnames compare case-insensitively once lowercased.
    #[test]
    fn test_hostname_normalization() {
        assert_eq!("Example.COM".to_lowercase(), "example.com");
    }
}