sentinel_proxy/
tls.rs

//! TLS Configuration and SNI Support
//!
//! This module provides TLS configuration with Server Name Indication (SNI) support
//! for serving multiple certificates based on the requested hostname.
//!
//! # Features
//!
//! - SNI-based certificate selection
//! - Wildcard certificate matching (e.g., `*.example.com` matches `foo.example.com`)
//! - Default certificate fallback
//! - Certificate validation at startup
//! - mTLS client certificate verification
//! - Certificate hot-reload on SIGHUP
//! - OCSP response fetching and caching for stapling (prefetch is not yet fully implemented)
//!
//! # Example KDL Configuration
//!
//! ```kdl
//! listener "https" {
//!     address "0.0.0.0:443"
//!     protocol "https"
//!     tls {
//!         cert-file "/etc/certs/default.crt"
//!         key-file "/etc/certs/default.key"
//!
//!         // SNI certificates
//!         sni {
//!             hostnames "example.com" "www.example.com"
//!             cert-file "/etc/certs/example.crt"
//!             key-file "/etc/certs/example.key"
//!         }
//!         sni {
//!             hostnames "*.api.example.com"
//!             cert-file "/etc/certs/api-wildcard.crt"
//!             key-file "/etc/certs/api-wildcard.key"
//!         }
//!
//!         // mTLS configuration
//!         ca-file "/etc/certs/ca.crt"
//!         client-auth true
//!
//!         // OCSP stapling
//!         ocsp-stapling true
//!     }
//! }
//! ```
47
48use std::collections::HashMap;
49use std::fs::File;
50use std::io::BufReader;
51use std::path::Path;
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54
55use parking_lot::RwLock;
56use rustls::client::ClientConfig;
57use rustls::pki_types::CertificateDer;
58use rustls::server::{ClientHello, ResolvesServerCert};
59use rustls::sign::CertifiedKey;
60use rustls::{RootCertStore, ServerConfig};
61use tracing::{debug, error, info, trace, warn};
62
63use sentinel_config::{TlsConfig, UpstreamTlsConfig};
64
/// Errors produced while loading TLS material, building TLS configurations,
/// or fetching OCSP responses.
///
/// Every variant carries a free-form detail string; `Display` renders it
/// after a fixed, variant-specific prefix (`"<prefix>: <detail>"`).
#[derive(Debug)]
pub enum TlsError {
    /// A certificate file could not be opened or parsed.
    CertificateLoad(String),
    /// A private key file could not be opened or parsed.
    KeyLoad(String),
    /// Assembling a TLS configuration failed.
    ConfigBuild(String),
    /// A certificate and its private key do not match.
    CertKeyMismatch(String),
    /// A certificate was rejected as invalid.
    InvalidCertificate(String),
    /// Preparing or performing an OCSP fetch failed (certificate parsing,
    /// responder URL lookup, or the HTTP exchange).
    OcspFetch(String),
}

impl std::fmt::Display for TlsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Select the fixed prefix for the variant, then append the detail text.
        let (prefix, detail) = match self {
            Self::CertificateLoad(detail) => ("Failed to load certificate", detail),
            Self::KeyLoad(detail) => ("Failed to load private key", detail),
            Self::ConfigBuild(detail) => ("Failed to build TLS config", detail),
            Self::CertKeyMismatch(detail) => ("Certificate/key mismatch", detail),
            Self::InvalidCertificate(detail) => ("Invalid certificate", detail),
            Self::OcspFetch(detail) => ("Failed to fetch OCSP response", detail),
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for TlsError {}
96
97/// SNI-aware certificate resolver
98///
99/// Resolves certificates based on the Server Name Indication (SNI) extension
100/// in the TLS handshake. Supports:
101/// - Exact hostname matches
102/// - Wildcard certificates (e.g., `*.example.com`)
103/// - Default certificate fallback
104#[derive(Debug)]
105pub struct SniResolver {
106    /// Default certificate (used when no SNI match)
107    default_cert: Arc<CertifiedKey>,
108    /// SNI hostname to certificate mapping
109    /// Key is lowercase hostname, value is the certified key
110    sni_certs: HashMap<String, Arc<CertifiedKey>>,
111    /// Wildcard certificates (e.g., "*.example.com" -> cert)
112    wildcard_certs: HashMap<String, Arc<CertifiedKey>>,
113}
114
115impl SniResolver {
116    /// Create a new SNI resolver from TLS configuration
117    pub fn from_config(config: &TlsConfig) -> Result<Self, TlsError> {
118        // Load default certificate
119        let default_cert = load_certified_key(&config.cert_file, &config.key_file)?;
120
121        info!(
122            cert_file = %config.cert_file.display(),
123            "Loaded default TLS certificate"
124        );
125
126        let mut sni_certs = HashMap::new();
127        let mut wildcard_certs = HashMap::new();
128
129        // Load SNI certificates
130        for sni_config in &config.additional_certs {
131            let cert = load_certified_key(&sni_config.cert_file, &sni_config.key_file)?;
132            let cert = Arc::new(cert);
133
134            for hostname in &sni_config.hostnames {
135                let hostname_lower = hostname.to_lowercase();
136
137                if hostname_lower.starts_with("*.") {
138                    // Wildcard certificate
139                    let domain = hostname_lower.strip_prefix("*.").unwrap().to_string();
140                    wildcard_certs.insert(domain.clone(), cert.clone());
141                    debug!(
142                        pattern = %hostname,
143                        domain = %domain,
144                        cert_file = %sni_config.cert_file.display(),
145                        "Registered wildcard SNI certificate"
146                    );
147                } else {
148                    // Exact hostname match
149                    sni_certs.insert(hostname_lower.clone(), cert.clone());
150                    debug!(
151                        hostname = %hostname_lower,
152                        cert_file = %sni_config.cert_file.display(),
153                        "Registered SNI certificate"
154                    );
155                }
156            }
157        }
158
159        info!(
160            exact_certs = sni_certs.len(),
161            wildcard_certs = wildcard_certs.len(),
162            "SNI resolver initialized"
163        );
164
165        Ok(Self {
166            default_cert: Arc::new(default_cert),
167            sni_certs,
168            wildcard_certs,
169        })
170    }
171
172    /// Resolve certificate for a given server name
173    ///
174    /// This is the core resolution logic. For the rustls trait implementation,
175    /// see `ResolvesServerCert`.
176    pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
177        let Some(name) = server_name else {
178            debug!("No SNI provided, using default certificate");
179            return self.default_cert.clone();
180        };
181
182        let name_lower = name.to_lowercase();
183
184        // Try exact match first
185        if let Some(cert) = self.sni_certs.get(&name_lower) {
186            debug!(hostname = %name_lower, "SNI exact match found");
187            return cert.clone();
188        }
189
190        // Try wildcard match
191        // For "foo.bar.example.com", try "bar.example.com", then "example.com"
192        let parts: Vec<&str> = name_lower.split('.').collect();
193        for i in 1..parts.len() {
194            let domain = parts[i..].join(".");
195            if let Some(cert) = self.wildcard_certs.get(&domain) {
196                debug!(
197                    hostname = %name_lower,
198                    wildcard_domain = %domain,
199                    "SNI wildcard match found"
200                );
201                return cert.clone();
202            }
203        }
204
205        debug!(
206            hostname = %name_lower,
207            "No SNI match found, using default certificate"
208        );
209        self.default_cert.clone()
210    }
211}
212
213impl ResolvesServerCert for SniResolver {
214    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
215        Some(self.resolve(client_hello.server_name()))
216    }
217}
218
219// ============================================================================
220// Hot-Reloadable Certificate Support
221// ============================================================================
222
223/// Hot-reloadable SNI certificate resolver
224///
225/// Wraps an SniResolver behind an RwLock to allow certificate hot-reload
226/// without restarting the server. On SIGHUP, the inner resolver is replaced
227/// with a newly loaded one.
228pub struct HotReloadableSniResolver {
229    /// Inner resolver (protected by RwLock for hot-reload)
230    inner: RwLock<Arc<SniResolver>>,
231    /// Original config for reloading
232    config: RwLock<TlsConfig>,
233    /// Last reload time
234    last_reload: RwLock<Instant>,
235}
236
237impl std::fmt::Debug for HotReloadableSniResolver {
238    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        f.debug_struct("HotReloadableSniResolver")
240            .field("last_reload", &*self.last_reload.read())
241            .finish()
242    }
243}
244
245impl HotReloadableSniResolver {
246    /// Create a new hot-reloadable resolver from TLS configuration
247    pub fn from_config(config: TlsConfig) -> Result<Self, TlsError> {
248        let resolver = SniResolver::from_config(&config)?;
249
250        Ok(Self {
251            inner: RwLock::new(Arc::new(resolver)),
252            config: RwLock::new(config),
253            last_reload: RwLock::new(Instant::now()),
254        })
255    }
256
257    /// Reload certificates from disk
258    ///
259    /// This is called on SIGHUP to pick up new certificates without restart.
260    /// If the reload fails, the old certificates continue to be used.
261    pub fn reload(&self) -> Result<(), TlsError> {
262        let config = self.config.read();
263
264        info!(
265            cert_file = %config.cert_file.display(),
266            sni_count = config.additional_certs.len(),
267            "Reloading TLS certificates"
268        );
269
270        // Try to load new certificates
271        let new_resolver = SniResolver::from_config(&config)?;
272
273        // Swap in the new resolver atomically
274        *self.inner.write() = Arc::new(new_resolver);
275        *self.last_reload.write() = Instant::now();
276
277        info!("TLS certificates reloaded successfully");
278        Ok(())
279    }
280
281    /// Update configuration and reload
282    pub fn update_config(&self, new_config: TlsConfig) -> Result<(), TlsError> {
283        // Load with new config first
284        let new_resolver = SniResolver::from_config(&new_config)?;
285
286        // Update both config and resolver
287        *self.config.write() = new_config;
288        *self.inner.write() = Arc::new(new_resolver);
289        *self.last_reload.write() = Instant::now();
290
291        info!("TLS configuration updated and certificates reloaded");
292        Ok(())
293    }
294
295    /// Get time since last reload
296    pub fn last_reload_age(&self) -> Duration {
297        self.last_reload.read().elapsed()
298    }
299
300    /// Resolve certificate for a given server name
301    ///
302    /// This is the core resolution logic exposed for testing.
303    pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
304        self.inner.read().resolve(server_name)
305    }
306}
307
308impl ResolvesServerCert for HotReloadableSniResolver {
309    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
310        Some(self.inner.read().resolve(client_hello.server_name()))
311    }
312}
313
314/// Certificate reload manager
315///
316/// Tracks all TLS listeners and provides a unified reload interface.
317pub struct CertificateReloader {
318    /// Map of listener ID to hot-reloadable resolver
319    resolvers: RwLock<HashMap<String, Arc<HotReloadableSniResolver>>>,
320}
321
322impl CertificateReloader {
323    /// Create a new certificate reloader
324    pub fn new() -> Self {
325        Self {
326            resolvers: RwLock::new(HashMap::new()),
327        }
328    }
329
330    /// Register a resolver for a listener
331    pub fn register(&self, listener_id: &str, resolver: Arc<HotReloadableSniResolver>) {
332        debug!(listener_id = %listener_id, "Registering TLS resolver for hot-reload");
333        self.resolvers
334            .write()
335            .insert(listener_id.to_string(), resolver);
336    }
337
338    /// Reload all registered certificates
339    ///
340    /// Returns the number of successfully reloaded listeners and any errors.
341    pub fn reload_all(&self) -> (usize, Vec<(String, TlsError)>) {
342        let resolvers = self.resolvers.read();
343        let mut success_count = 0;
344        let mut errors = Vec::new();
345
346        info!(
347            listener_count = resolvers.len(),
348            "Reloading certificates for all TLS listeners"
349        );
350
351        for (listener_id, resolver) in resolvers.iter() {
352            match resolver.reload() {
353                Ok(()) => {
354                    success_count += 1;
355                    debug!(listener_id = %listener_id, "Certificate reload successful");
356                }
357                Err(e) => {
358                    error!(listener_id = %listener_id, error = %e, "Certificate reload failed");
359                    errors.push((listener_id.clone(), e));
360                }
361            }
362        }
363
364        if errors.is_empty() {
365            info!(
366                success_count = success_count,
367                "All certificates reloaded successfully"
368            );
369        } else {
370            warn!(
371                success_count = success_count,
372                error_count = errors.len(),
373                "Certificate reload completed with errors"
374            );
375        }
376
377        (success_count, errors)
378    }
379
380    /// Get reload status for all listeners
381    pub fn status(&self) -> HashMap<String, Duration> {
382        self.resolvers
383            .read()
384            .iter()
385            .map(|(id, resolver)| (id.clone(), resolver.last_reload_age()))
386            .collect()
387    }
388}
389
390impl Default for CertificateReloader {
391    fn default() -> Self {
392        Self::new()
393    }
394}
395
396// ============================================================================
397// OCSP Stapling Support
398// ============================================================================
399
400/// OCSP response cache entry
401#[derive(Debug, Clone)]
402pub struct OcspCacheEntry {
403    /// DER-encoded OCSP response
404    pub response: Vec<u8>,
405    /// When this response was fetched
406    pub fetched_at: Instant,
407    /// When this response expires (from nextUpdate field)
408    pub expires_at: Option<Instant>,
409}
410
411/// OCSP stapling manager
412///
413/// Fetches and caches OCSP responses for certificates.
414pub struct OcspStapler {
415    /// Cache of OCSP responses by certificate fingerprint
416    cache: RwLock<HashMap<String, OcspCacheEntry>>,
417    /// Refresh interval for OCSP responses (default 1 hour)
418    refresh_interval: Duration,
419}
420
421impl OcspStapler {
422    /// Create a new OCSP stapler
423    pub fn new() -> Self {
424        Self {
425            cache: RwLock::new(HashMap::new()),
426            refresh_interval: Duration::from_secs(3600), // 1 hour default
427        }
428    }
429
430    /// Create with custom refresh interval
431    pub fn with_refresh_interval(interval: Duration) -> Self {
432        Self {
433            cache: RwLock::new(HashMap::new()),
434            refresh_interval: interval,
435        }
436    }
437
438    /// Get cached OCSP response for a certificate
439    pub fn get_response(&self, cert_fingerprint: &str) -> Option<Vec<u8>> {
440        let cache = self.cache.read();
441        if let Some(entry) = cache.get(cert_fingerprint) {
442            // Check if response is still valid
443            if entry.fetched_at.elapsed() < self.refresh_interval {
444                trace!(fingerprint = %cert_fingerprint, "OCSP cache hit");
445                return Some(entry.response.clone());
446            }
447            trace!(fingerprint = %cert_fingerprint, "OCSP cache expired");
448        }
449        None
450    }
451
452    /// Fetch OCSP response for a certificate
453    ///
454    /// This performs an HTTP request to the OCSP responder specified in the
455    /// certificate's Authority Information Access extension.
456    pub fn fetch_ocsp_response(
457        &self,
458        cert_der: &[u8],
459        issuer_der: &[u8],
460    ) -> Result<Vec<u8>, TlsError> {
461        use x509_parser::prelude::*;
462
463        // Parse the end-entity certificate
464        let (_, cert) = X509Certificate::from_der(cert_der)
465            .map_err(|e| TlsError::OcspFetch(format!("Failed to parse certificate: {}", e)))?;
466
467        // Parse the issuer certificate
468        let (_, issuer) = X509Certificate::from_der(issuer_der)
469            .map_err(|e| TlsError::OcspFetch(format!("Failed to parse issuer certificate: {}", e)))?;
470
471        // Extract OCSP responder URL from AIA extension
472        let ocsp_url = extract_ocsp_responder_url(&cert)?;
473        debug!(url = %ocsp_url, "Found OCSP responder URL");
474
475        // Build OCSP request
476        let ocsp_request = build_ocsp_request(&cert, &issuer)?;
477
478        // Send request synchronously (blocking context)
479        // Note: In production, this should be async with proper timeout handling
480        let response = send_ocsp_request_sync(&ocsp_url, &ocsp_request)?;
481
482        // Calculate fingerprint for caching
483        let fingerprint = calculate_cert_fingerprint(cert_der);
484
485        // Cache the response
486        let entry = OcspCacheEntry {
487            response: response.clone(),
488            fetched_at: Instant::now(),
489            expires_at: None, // Could parse nextUpdate from response
490        };
491        self.cache.write().insert(fingerprint, entry);
492
493        info!("Successfully fetched and cached OCSP response");
494        Ok(response)
495    }
496
497    /// Async version of fetch_ocsp_response
498    pub async fn fetch_ocsp_response_async(
499        &self,
500        cert_der: &[u8],
501        issuer_der: &[u8],
502    ) -> Result<Vec<u8>, TlsError> {
503        use x509_parser::prelude::*;
504
505        // Parse the end-entity certificate
506        let (_, cert) = X509Certificate::from_der(cert_der)
507            .map_err(|e| TlsError::OcspFetch(format!("Failed to parse certificate: {}", e)))?;
508
509        // Parse the issuer certificate
510        let (_, issuer) = X509Certificate::from_der(issuer_der)
511            .map_err(|e| TlsError::OcspFetch(format!("Failed to parse issuer certificate: {}", e)))?;
512
513        // Extract OCSP responder URL from AIA extension
514        let ocsp_url = extract_ocsp_responder_url(&cert)?;
515        debug!(url = %ocsp_url, "Found OCSP responder URL");
516
517        // Build OCSP request
518        let ocsp_request = build_ocsp_request(&cert, &issuer)?;
519
520        // Send request asynchronously
521        let response = send_ocsp_request_async(&ocsp_url, &ocsp_request).await?;
522
523        // Calculate fingerprint for caching
524        let fingerprint = calculate_cert_fingerprint(cert_der);
525
526        // Cache the response
527        let entry = OcspCacheEntry {
528            response: response.clone(),
529            fetched_at: Instant::now(),
530            expires_at: None,
531        };
532        self.cache.write().insert(fingerprint, entry);
533
534        info!("Successfully fetched and cached OCSP response (async)");
535        Ok(response)
536    }
537
538    /// Prefetch OCSP responses for all certificates in a config
539    pub fn prefetch_for_config(&self, config: &TlsConfig) -> Vec<String> {
540        let mut warnings = Vec::new();
541
542        if !config.ocsp_stapling {
543            trace!("OCSP stapling disabled in config");
544            return warnings;
545        }
546
547        info!("Prefetching OCSP responses for certificates");
548
549        // For now, just log that we would prefetch
550        // Full implementation would iterate certificates and fetch OCSP responses
551        warnings.push("OCSP stapling prefetch not yet fully implemented".to_string());
552
553        warnings
554    }
555
556    /// Clear the OCSP cache
557    pub fn clear_cache(&self) {
558        self.cache.write().clear();
559        info!("OCSP cache cleared");
560    }
561}
562
563impl Default for OcspStapler {
564    fn default() -> Self {
565        Self::new()
566    }
567}
568
569// ============================================================================
570// OCSP Helper Functions
571// ============================================================================
572
573/// Extract OCSP responder URL from certificate's Authority Information Access extension
574fn extract_ocsp_responder_url(cert: &x509_parser::certificate::X509Certificate) -> Result<String, TlsError> {
575    use x509_parser::prelude::*;
576
577    // Find the AIA extension
578    let aia = cert
579        .extensions()
580        .iter()
581        .find(|ext| ext.oid == oid_registry::OID_PKIX_AUTHORITY_INFO_ACCESS)
582        .ok_or_else(|| TlsError::OcspFetch(
583            "Certificate does not have Authority Information Access extension".to_string()
584        ))?;
585
586    // Parse AIA extension
587    let aia_value = match aia.parsed_extension() {
588        ParsedExtension::AuthorityInfoAccess(aia) => aia,
589        _ => return Err(TlsError::OcspFetch(
590            "Failed to parse Authority Information Access extension".to_string()
591        )),
592    };
593
594    // Find OCSP access method
595    for access in &aia_value.accessdescs {
596        if access.access_method == oid_registry::OID_PKIX_ACCESS_DESCRIPTOR_OCSP {
597            match &access.access_location {
598                GeneralName::URI(url) => {
599                    return Ok(url.to_string());
600                }
601                _ => continue,
602            }
603        }
604    }
605
606    Err(TlsError::OcspFetch(
607        "Certificate AIA does not contain OCSP responder URL".to_string()
608    ))
609}
610
611/// Build an OCSP request for the given certificate
612///
613/// This builds a minimal OCSP request with SHA-256 hashes
614fn build_ocsp_request(
615    cert: &x509_parser::certificate::X509Certificate,
616    issuer: &x509_parser::certificate::X509Certificate,
617) -> Result<Vec<u8>, TlsError> {
618    use sha2::{Sha256, Digest};
619
620    // Per RFC 6960, an OCSP request contains:
621    // - Hash of issuer name
622    // - Hash of issuer public key
623    // - Certificate serial number
624
625    // Hash issuer name (Distinguished Name)
626    let issuer_name_hash = {
627        let mut hasher = Sha256::new();
628        hasher.update(issuer.subject().as_raw());
629        hasher.finalize()
630    };
631
632    // Hash issuer public key (the BIT STRING content, not including tag/length)
633    let issuer_key_hash = {
634        let mut hasher = Sha256::new();
635        hasher.update(issuer.public_key().subject_public_key.data.as_ref());
636        hasher.finalize()
637    };
638
639    // Get certificate serial number
640    let serial = cert.serial.to_bytes_be();
641
642    // Build ASN.1 DER encoded OCSP request
643    // This is a minimal implementation of the OCSP request structure
644    let request = build_ocsp_request_der(
645        &issuer_name_hash,
646        &issuer_key_hash,
647        &serial,
648    );
649
650    Ok(request)
651}
652
/// Build a DER-encoded OCSP request containing a single `CertID`.
///
/// Produces this RFC 6960 §4.1.1 structure (version left at its v1 default,
/// no requestor name, no signature, no extensions):
///
/// ```text
/// OCSPRequest ::= SEQUENCE { tbsRequest }
/// TBSRequest  ::= SEQUENCE { requestList }
/// requestList ::= SEQUENCE OF Request          -- exactly one Request
/// Request     ::= SEQUENCE { reqCert CertID }
/// CertID      ::= SEQUENCE { hashAlgorithm, issuerNameHash,
///                            issuerKeyHash, serialNumber }
/// ```
///
/// `hashAlgorithm` is always id-sha256 with explicit NULL parameters, so both
/// hash arguments must be SHA-256 digests. `serial_number` is the big-endian
/// magnitude of the (non-negative) serial.
fn build_ocsp_request_der(
    issuer_name_hash: &[u8],
    issuer_key_hash: &[u8],
    serial_number: &[u8],
) -> Vec<u8> {
    // Content octets of OID 2.16.840.1.101.3.4.2.1 (id-sha256).
    const SHA256_OID: &[u8] = &[0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01];

    let hash_algorithm = der_sequence(&[&der_oid(SHA256_OID), &der_null()]);

    let cert_id = der_sequence(&[
        &hash_algorithm,
        &der_octet_string(issuer_name_hash),
        &der_octet_string(issuer_key_hash),
        &der_integer(serial_number),
    ]);

    let request = der_sequence(&[&cert_id]);
    let request_list = der_sequence(&[&request]);
    let tbs_request = der_sequence(&[&request_list]);
    der_sequence(&[&tbs_request])
}

// DER encoding helpers (definite-length, minimal encodings per X.690).

/// Encode one tag-length-value triple: `tag`, the DER length of `content`,
/// then `content`.
fn der_tlv(tag: u8, content: &[u8]) -> Vec<u8> {
    let length = der_length(content.len());
    let mut out = Vec::with_capacity(1 + length.len() + content.len());
    out.push(tag);
    out.extend_from_slice(&length);
    out.extend_from_slice(content);
    out
}

/// Encode a SEQUENCE whose contents are the concatenation of `items`, each
/// of which must already be DER-encoded.
fn der_sequence(items: &[&[u8]]) -> Vec<u8> {
    der_tlv(0x30, &items.concat())
}

/// Encode an OBJECT IDENTIFIER from its already-encoded content octets.
fn der_oid(oid: &[u8]) -> Vec<u8> {
    der_tlv(0x06, oid)
}

/// Encode ASN.1 NULL.
fn der_null() -> Vec<u8> {
    vec![0x05, 0x00]
}

/// Encode an OCTET STRING.
fn der_octet_string(data: &[u8]) -> Vec<u8> {
    der_tlv(0x04, data)
}

/// Encode a non-negative INTEGER from its big-endian magnitude bytes.
///
/// Redundant leading zero octets are dropped; empty or all-zero input
/// encodes zero as the single octet `0x00`. A `0x00` octet is prepended when
/// the high bit is set, so the value is not read as negative.
fn der_integer(data: &[u8]) -> Vec<u8> {
    let magnitude: &[u8] = match data.iter().position(|&b| b != 0) {
        Some(first_nonzero) => &data[first_nonzero..],
        None => &[0],
    };

    let mut content = Vec::with_capacity(magnitude.len() + 1);
    if magnitude[0] & 0x80 != 0 {
        content.push(0x00);
    }
    content.extend_from_slice(magnitude);
    der_tlv(0x02, &content)
}

/// Encode a DER definite length (X.690 §8.1.3).
///
/// Lengths below 128 use the one-octet short form. Larger lengths use the
/// long form: `0x80 | n`, followed by the `n` big-endian length octets with
/// no leading zeros. Every `usize` value is encoded correctly.
fn der_length(len: usize) -> Vec<u8> {
    if len < 0x80 {
        return vec![len as u8];
    }

    let be = len.to_be_bytes();
    let skip = be.iter().take_while(|&&b| b == 0).count();
    let significant = &be[skip..];

    let mut out = Vec::with_capacity(1 + significant.len());
    // `significant.len()` is at most size_of::<usize>(), well below 0x80.
    out.push(0x80 | significant.len() as u8);
    out.extend_from_slice(significant);
    out
}
745
746/// Send OCSP request synchronously (blocking)
747fn send_ocsp_request_sync(url: &str, request: &[u8]) -> Result<Vec<u8>, TlsError> {
748    use std::io::{Read, Write};
749    use std::net::TcpStream;
750    use std::time::Duration;
751
752    // Parse URL to get host, port, and path
753    let url = url::Url::parse(url)
754        .map_err(|e| TlsError::OcspFetch(format!("Invalid OCSP URL: {}", e)))?;
755
756    let host = url.host_str()
757        .ok_or_else(|| TlsError::OcspFetch("OCSP URL has no host".to_string()))?;
758    let port = url.port().unwrap_or(80);
759    let path = if url.path().is_empty() { "/" } else { url.path() };
760
761    // Connect to server
762    let addr = format!("{}:{}", host, port);
763    let mut stream = TcpStream::connect(&addr)
764        .map_err(|e| TlsError::OcspFetch(format!("Failed to connect to OCSP responder: {}", e)))?;
765
766    stream.set_read_timeout(Some(Duration::from_secs(10)))
767        .map_err(|e| TlsError::OcspFetch(format!("Failed to set timeout: {}", e)))?;
768    stream.set_write_timeout(Some(Duration::from_secs(10)))
769        .map_err(|e| TlsError::OcspFetch(format!("Failed to set timeout: {}", e)))?;
770
771    // Build HTTP POST request
772    let http_request = format!(
773        "POST {} HTTP/1.1\r\n\
774         Host: {}\r\n\
775         Content-Type: application/ocsp-request\r\n\
776         Content-Length: {}\r\n\
777         Connection: close\r\n\
778         \r\n",
779        path, host, request.len()
780    );
781
782    // Send request
783    stream.write_all(http_request.as_bytes())
784        .map_err(|e| TlsError::OcspFetch(format!("Failed to send OCSP request: {}", e)))?;
785    stream.write_all(request)
786        .map_err(|e| TlsError::OcspFetch(format!("Failed to send OCSP request body: {}", e)))?;
787
788    // Read response
789    let mut response = Vec::new();
790    stream.read_to_end(&mut response)
791        .map_err(|e| TlsError::OcspFetch(format!("Failed to read OCSP response: {}", e)))?;
792
793    // Parse HTTP response - find body after headers
794    let headers_end = response.windows(4)
795        .position(|w| w == b"\r\n\r\n")
796        .ok_or_else(|| TlsError::OcspFetch("Invalid HTTP response: no headers end".to_string()))?;
797
798    let body = &response[headers_end + 4..];
799    if body.is_empty() {
800        return Err(TlsError::OcspFetch("Empty OCSP response body".to_string()));
801    }
802
803    Ok(body.to_vec())
804}
805
806/// Send OCSP request asynchronously
807async fn send_ocsp_request_async(url: &str, request: &[u8]) -> Result<Vec<u8>, TlsError> {
808    let client = reqwest::Client::builder()
809        .timeout(Duration::from_secs(10))
810        .build()
811        .map_err(|e| TlsError::OcspFetch(format!("Failed to create HTTP client: {}", e)))?;
812
813    let response = client
814        .post(url)
815        .header("Content-Type", "application/ocsp-request")
816        .body(request.to_vec())
817        .send()
818        .await
819        .map_err(|e| TlsError::OcspFetch(format!("OCSP request failed: {}", e)))?;
820
821    if !response.status().is_success() {
822        return Err(TlsError::OcspFetch(format!(
823            "OCSP responder returned status: {}",
824            response.status()
825        )));
826    }
827
828    let body = response.bytes().await
829        .map_err(|e| TlsError::OcspFetch(format!("Failed to read OCSP response: {}", e)))?;
830
831    Ok(body.to_vec())
832}
833
834/// Calculate certificate fingerprint for cache key
835fn calculate_cert_fingerprint(cert_der: &[u8]) -> String {
836    use sha2::{Sha256, Digest};
837    let mut hasher = Sha256::new();
838    hasher.update(cert_der);
839    let result = hasher.finalize();
840    hex::encode(result)
841}
842
843// ============================================================================
844// Upstream mTLS Support (Client Certificates)
845// ============================================================================
846
847/// Load client certificate and key for mTLS to upstreams
848///
849/// This function loads PEM-encoded certificates and private key and converts
850/// them to Pingora's CertKey format for use with `HttpPeer.client_cert_key`.
851///
852/// # Arguments
853///
854/// * `cert_path` - Path to PEM-encoded certificate (may include chain)
855/// * `key_path` - Path to PEM-encoded private key
856///
857/// # Returns
858///
859/// An `Arc<CertKey>` that can be set on `peer.client_cert_key` for mTLS
860pub fn load_client_cert_key(
861    cert_path: &Path,
862    key_path: &Path,
863) -> Result<Arc<pingora_core::utils::tls::CertKey>, TlsError> {
864    // Read certificate chain (PEM format, may contain intermediates)
865    let cert_file = File::open(cert_path)
866        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
867    let mut cert_reader = BufReader::new(cert_file);
868
869    // Parse certificates from PEM to DER
870    let cert_ders: Vec<Vec<u8>> = rustls_pemfile::certs(&mut cert_reader)
871        .collect::<Result<Vec<_>, _>>()
872        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?
873        .into_iter()
874        .map(|c| c.to_vec())
875        .collect();
876
877    if cert_ders.is_empty() {
878        return Err(TlsError::CertificateLoad(format!(
879            "{}: No certificates found in PEM file",
880            cert_path.display()
881        )));
882    }
883
884    // Read private key (PEM format)
885    let key_file = File::open(key_path)
886        .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
887    let mut key_reader = BufReader::new(key_file);
888
889    // Parse private key from PEM to DER
890    let key_der = rustls_pemfile::private_key(&mut key_reader)
891        .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
892        .ok_or_else(|| {
893            TlsError::KeyLoad(format!(
894                "{}: No private key found in PEM file",
895                key_path.display()
896            ))
897        })?
898        .secret_der()
899        .to_vec();
900
901    // Create Pingora's CertKey (certificates: Vec<Vec<u8>>, key: Vec<u8>)
902    let cert_key = pingora_core::utils::tls::CertKey::new(cert_ders, key_der);
903
904    debug!(
905        cert_path = %cert_path.display(),
906        key_path = %key_path.display(),
907        "Loaded mTLS client certificate for upstream connections"
908    );
909
910    Ok(Arc::new(cert_key))
911}
912
913/// Build a TLS client configuration for upstream connections with mTLS
914///
915/// This creates a rustls ClientConfig that can be used when Sentinel
916/// connects to backends that require client certificate authentication.
917pub fn build_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<ClientConfig, TlsError> {
918    let mut root_store = RootCertStore::empty();
919
920    // Load CA certificates for server verification
921    if let Some(ca_path) = &config.ca_cert {
922        let ca_file = File::open(ca_path)
923            .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
924        let mut ca_reader = BufReader::new(ca_file);
925
926        let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
927            .collect::<Result<Vec<_>, _>>()
928            .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
929
930        for cert in certs {
931            root_store.add(cert).map_err(|e| {
932                TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
933            })?;
934        }
935
936        debug!(
937            ca_file = %ca_path.display(),
938            cert_count = root_store.len(),
939            "Loaded upstream CA certificates"
940        );
941    } else if !config.insecure_skip_verify {
942        // Use webpki roots for standard TLS
943        root_store = RootCertStore {
944            roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
945        };
946        trace!("Using webpki-roots for upstream TLS verification");
947    }
948
949    // Build the client config
950    let builder = ClientConfig::builder().with_root_certificates(root_store);
951
952    let client_config = if let (Some(cert_path), Some(key_path)) =
953        (&config.client_cert, &config.client_key)
954    {
955        // Load client certificate for mTLS
956        let cert_file = File::open(cert_path)
957            .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
958        let mut cert_reader = BufReader::new(cert_file);
959
960        let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
961            .collect::<Result<Vec<_>, _>>()
962            .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
963
964        if certs.is_empty() {
965            return Err(TlsError::CertificateLoad(format!(
966                "{}: No certificates found",
967                cert_path.display()
968            )));
969        }
970
971        // Load client private key
972        let key_file = File::open(key_path)
973            .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
974        let mut key_reader = BufReader::new(key_file);
975
976        let key = rustls_pemfile::private_key(&mut key_reader)
977            .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
978            .ok_or_else(|| {
979                TlsError::KeyLoad(format!("{}: No private key found", key_path.display()))
980            })?;
981
982        info!(
983            cert_file = %cert_path.display(),
984            "Configured mTLS client certificate for upstream connections"
985        );
986
987        builder
988            .with_client_auth_cert(certs, key)
989            .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to set client auth: {}", e)))?
990    } else {
991        // No client certificate
992        builder.with_no_client_auth()
993    };
994
995    debug!("Upstream TLS configuration built successfully");
996    Ok(client_config)
997}
998
999/// Validate upstream TLS configuration
1000pub fn validate_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<(), TlsError> {
1001    // Validate CA certificate if specified
1002    if let Some(ca_path) = &config.ca_cert {
1003        if !ca_path.exists() {
1004            return Err(TlsError::CertificateLoad(format!(
1005                "Upstream CA certificate not found: {}",
1006                ca_path.display()
1007            )));
1008        }
1009    }
1010
1011    // Validate client certificate pair if mTLS is configured
1012    if let Some(cert_path) = &config.client_cert {
1013        if !cert_path.exists() {
1014            return Err(TlsError::CertificateLoad(format!(
1015                "Upstream client certificate not found: {}",
1016                cert_path.display()
1017            )));
1018        }
1019
1020        // If cert is specified, key must also be specified
1021        match &config.client_key {
1022            Some(key_path) if !key_path.exists() => {
1023                return Err(TlsError::KeyLoad(format!(
1024                    "Upstream client key not found: {}",
1025                    key_path.display()
1026                )));
1027            }
1028            None => {
1029                return Err(TlsError::ConfigBuild(
1030                    "client_cert specified without client_key".to_string(),
1031                ));
1032            }
1033            _ => {}
1034        }
1035    }
1036
1037    if config.client_key.is_some() && config.client_cert.is_none() {
1038        return Err(TlsError::ConfigBuild(
1039            "client_key specified without client_cert".to_string(),
1040        ));
1041    }
1042
1043    Ok(())
1044}
1045
1046// ============================================================================
1047// Certificate Loading Functions
1048// ============================================================================
1049
1050/// Load a certificate chain and private key from files
1051fn load_certified_key(cert_path: &Path, key_path: &Path) -> Result<CertifiedKey, TlsError> {
1052    // Load certificate chain
1053    let cert_file = File::open(cert_path)
1054        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1055    let mut cert_reader = BufReader::new(cert_file);
1056
1057    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
1058        .collect::<Result<Vec<_>, _>>()
1059        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1060
1061    if certs.is_empty() {
1062        return Err(TlsError::CertificateLoad(format!(
1063            "{}: No certificates found in file",
1064            cert_path.display()
1065        )));
1066    }
1067
1068    // Load private key
1069    let key_file = File::open(key_path)
1070        .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
1071    let mut key_reader = BufReader::new(key_file);
1072
1073    let key = rustls_pemfile::private_key(&mut key_reader)
1074        .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
1075        .ok_or_else(|| {
1076            TlsError::KeyLoad(format!(
1077                "{}: No private key found in file",
1078                key_path.display()
1079            ))
1080        })?;
1081
1082    // Create signing key using the default crypto provider
1083    let provider = rustls::crypto::CryptoProvider::get_default()
1084        .cloned()
1085        .unwrap_or_else(|| Arc::new(rustls::crypto::aws_lc_rs::default_provider()));
1086
1087    let signing_key = provider
1088        .key_provider
1089        .load_private_key(key)
1090        .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to load private key: {:?}", e)))?;
1091
1092    Ok(CertifiedKey::new(certs, signing_key))
1093}
1094
1095/// Load CA certificates for client verification (mTLS)
1096pub fn load_client_ca(ca_path: &Path) -> Result<RootCertStore, TlsError> {
1097    let ca_file = File::open(ca_path)
1098        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
1099    let mut ca_reader = BufReader::new(ca_file);
1100
1101    let mut root_store = RootCertStore::empty();
1102
1103    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
1104        .collect::<Result<Vec<_>, _>>()
1105        .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
1106
1107    for cert in certs {
1108        root_store.add(cert).map_err(|e| {
1109            TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
1110        })?;
1111    }
1112
1113    if root_store.is_empty() {
1114        return Err(TlsError::CertificateLoad(format!(
1115            "{}: No CA certificates found",
1116            ca_path.display()
1117        )));
1118    }
1119
1120    info!(
1121        ca_file = %ca_path.display(),
1122        cert_count = root_store.len(),
1123        "Loaded client CA certificates"
1124    );
1125
1126    Ok(root_store)
1127}
1128
1129/// Build a TLS ServerConfig from our configuration
1130pub fn build_server_config(config: &TlsConfig) -> Result<ServerConfig, TlsError> {
1131    let resolver = SniResolver::from_config(config)?;
1132
1133    let builder = ServerConfig::builder();
1134
1135    // Configure client authentication (mTLS)
1136    let server_config = if config.client_auth {
1137        if let Some(ca_path) = &config.ca_file {
1138            let root_store = load_client_ca(ca_path)?;
1139            let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
1140                .build()
1141                .map_err(|e| {
1142                    TlsError::ConfigBuild(format!("Failed to build client verifier: {}", e))
1143                })?;
1144
1145            info!("mTLS enabled: client certificates required");
1146
1147            builder
1148                .with_client_cert_verifier(verifier)
1149                .with_cert_resolver(Arc::new(resolver))
1150        } else {
1151            warn!("client_auth enabled but no ca_file specified, disabling client auth");
1152            builder
1153                .with_no_client_auth()
1154                .with_cert_resolver(Arc::new(resolver))
1155        }
1156    } else {
1157        builder
1158            .with_no_client_auth()
1159            .with_cert_resolver(Arc::new(resolver))
1160    };
1161
1162    // Configure ALPN for HTTP/2 support
1163    let mut config = server_config;
1164    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
1165
1166    debug!("TLS configuration built successfully");
1167
1168    Ok(config)
1169}
1170
1171/// Validate TLS configuration files exist and are readable
1172pub fn validate_tls_config(config: &TlsConfig) -> Result<(), TlsError> {
1173    // Check default certificate
1174    if !config.cert_file.exists() {
1175        return Err(TlsError::CertificateLoad(format!(
1176            "Certificate file not found: {}",
1177            config.cert_file.display()
1178        )));
1179    }
1180    if !config.key_file.exists() {
1181        return Err(TlsError::KeyLoad(format!(
1182            "Key file not found: {}",
1183            config.key_file.display()
1184        )));
1185    }
1186
1187    // Check SNI certificates
1188    for sni in &config.additional_certs {
1189        if !sni.cert_file.exists() {
1190            return Err(TlsError::CertificateLoad(format!(
1191                "SNI certificate file not found: {}",
1192                sni.cert_file.display()
1193            )));
1194        }
1195        if !sni.key_file.exists() {
1196            return Err(TlsError::KeyLoad(format!(
1197                "SNI key file not found: {}",
1198                sni.key_file.display()
1199            )));
1200        }
1201    }
1202
1203    // Check CA file if mTLS enabled
1204    if config.client_auth {
1205        if let Some(ca_path) = &config.ca_file {
1206            if !ca_path.exists() {
1207                return Err(TlsError::CertificateLoad(format!(
1208                    "CA certificate file not found: {}",
1209                    ca_path.display()
1210                )));
1211            }
1212        }
1213    }
1214
1215    Ok(())
1216}
1217
#[cfg(test)]
mod tests {

    #[test]
    fn test_wildcard_matching() {
        // Only the label-splitting used for wildcard matching is exercised here;
        // no resolver or certificates are constructed.
        let labels: Vec<&str> = "foo.bar.example.com".split('.').collect();
        assert_eq!(labels.len(), 4);

        // Dropping leading labels yields the parent domains checked for wildcards.
        assert_eq!(labels[1..].join("."), "bar.example.com");
        assert_eq!(labels[2..].join("."), "example.com");
    }

    #[test]
    fn test_hostname_normalization() {
        assert_eq!("Example.COM".to_lowercase(), "example.com");
    }
}