1use std::collections::HashMap;
49use std::fs::File;
50use std::io::BufReader;
51use std::path::Path;
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54
55use parking_lot::RwLock;
56use rustls::client::ClientConfig;
57use rustls::pki_types::CertificateDer;
58use rustls::server::{ClientHello, ResolvesServerCert};
59use rustls::sign::CertifiedKey;
60use rustls::{RootCertStore, ServerConfig};
61use tracing::{debug, error, info, trace, warn};
62
63use sentinel_config::{TlsConfig, UpstreamTlsConfig};
64
/// Errors produced while loading TLS material or building TLS configurations.
///
/// Every variant carries a human-readable detail string (usually `"<path>: <cause>"`);
/// the `Display` impl prefixes it with a short description of the failing stage.
#[derive(Debug)]
pub enum TlsError {
    /// A certificate file could not be opened or parsed, or contained no certificates.
    CertificateLoad(String),
    /// A private-key file could not be opened or parsed, or contained no key.
    KeyLoad(String),
    /// A TLS configuration could not be assembled (missing settings, builder failure).
    ConfigBuild(String),
    /// A private key could not be loaded for signing or paired with its certificate chain.
    CertKeyMismatch(String),
    /// A certificate was rejected, e.g. when adding it to a root store.
    InvalidCertificate(String),
    /// Preparing, sending, or reading an OCSP request/response failed.
    OcspFetch(String),
}

impl std::fmt::Display for TlsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (stage, detail) = match self {
            Self::CertificateLoad(detail) => ("Failed to load certificate", detail),
            Self::KeyLoad(detail) => ("Failed to load private key", detail),
            Self::ConfigBuild(detail) => ("Failed to build TLS config", detail),
            Self::CertKeyMismatch(detail) => ("Certificate/key mismatch", detail),
            Self::InvalidCertificate(detail) => ("Invalid certificate", detail),
            Self::OcspFetch(detail) => ("Failed to fetch OCSP response", detail),
        };
        write!(f, "{}: {}", stage, detail)
    }
}

impl std::error::Error for TlsError {}
96
97#[derive(Debug)]
105pub struct SniResolver {
106 default_cert: Arc<CertifiedKey>,
108 sni_certs: HashMap<String, Arc<CertifiedKey>>,
111 wildcard_certs: HashMap<String, Arc<CertifiedKey>>,
113}
114
115impl SniResolver {
116 pub fn from_config(config: &TlsConfig) -> Result<Self, TlsError> {
118 let (cert_file, key_file) = match (&config.cert_file, &config.key_file) {
120 (Some(cert), Some(key)) => (cert, key),
121 _ => {
122 return Err(TlsError::ConfigBuild(
123 "TLS configuration requires cert_file and key_file".to_string(),
124 ));
125 }
126 };
127
128 let default_cert = load_certified_key(cert_file, key_file)?;
130
131 info!(
132 cert_file = %cert_file.display(),
133 "Loaded default TLS certificate"
134 );
135
136 let mut sni_certs = HashMap::new();
137 let mut wildcard_certs = HashMap::new();
138
139 for sni_config in &config.additional_certs {
141 let cert = load_certified_key(&sni_config.cert_file, &sni_config.key_file)?;
142 let cert = Arc::new(cert);
143
144 for hostname in &sni_config.hostnames {
145 let hostname_lower = hostname.to_lowercase();
146
147 if hostname_lower.starts_with("*.") {
148 let domain = hostname_lower.strip_prefix("*.").unwrap().to_string();
150 wildcard_certs.insert(domain.clone(), cert.clone());
151 debug!(
152 pattern = %hostname,
153 domain = %domain,
154 cert_file = %sni_config.cert_file.display(),
155 "Registered wildcard SNI certificate"
156 );
157 } else {
158 sni_certs.insert(hostname_lower.clone(), cert.clone());
160 debug!(
161 hostname = %hostname_lower,
162 cert_file = %sni_config.cert_file.display(),
163 "Registered SNI certificate"
164 );
165 }
166 }
167 }
168
169 info!(
170 exact_certs = sni_certs.len(),
171 wildcard_certs = wildcard_certs.len(),
172 "SNI resolver initialized"
173 );
174
175 Ok(Self {
176 default_cert: Arc::new(default_cert),
177 sni_certs,
178 wildcard_certs,
179 })
180 }
181
182 pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
187 let Some(name) = server_name else {
188 debug!("No SNI provided, using default certificate");
189 return self.default_cert.clone();
190 };
191
192 let name_lower = name.to_lowercase();
193
194 if let Some(cert) = self.sni_certs.get(&name_lower) {
196 debug!(hostname = %name_lower, "SNI exact match found");
197 return cert.clone();
198 }
199
200 let parts: Vec<&str> = name_lower.split('.').collect();
203 for i in 1..parts.len() {
204 let domain = parts[i..].join(".");
205 if let Some(cert) = self.wildcard_certs.get(&domain) {
206 debug!(
207 hostname = %name_lower,
208 wildcard_domain = %domain,
209 "SNI wildcard match found"
210 );
211 return cert.clone();
212 }
213 }
214
215 debug!(
216 hostname = %name_lower,
217 "No SNI match found, using default certificate"
218 );
219 self.default_cert.clone()
220 }
221}
222
223impl ResolvesServerCert for SniResolver {
224 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
225 Some(self.resolve(client_hello.server_name()))
226 }
227}
228
229pub struct HotReloadableSniResolver {
239 inner: RwLock<Arc<SniResolver>>,
241 config: RwLock<TlsConfig>,
243 last_reload: RwLock<Instant>,
245}
246
247impl std::fmt::Debug for HotReloadableSniResolver {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 f.debug_struct("HotReloadableSniResolver")
250 .field("last_reload", &*self.last_reload.read())
251 .finish()
252 }
253}
254
255impl HotReloadableSniResolver {
256 pub fn from_config(config: TlsConfig) -> Result<Self, TlsError> {
258 let resolver = SniResolver::from_config(&config)?;
259
260 Ok(Self {
261 inner: RwLock::new(Arc::new(resolver)),
262 config: RwLock::new(config),
263 last_reload: RwLock::new(Instant::now()),
264 })
265 }
266
267 pub fn reload(&self) -> Result<(), TlsError> {
272 let config = self.config.read();
273
274 let cert_file_display = config
275 .cert_file
276 .as_ref()
277 .map(|p| p.display().to_string())
278 .unwrap_or_else(|| "(acme-managed)".to_string());
279
280 info!(
281 cert_file = %cert_file_display,
282 sni_count = config.additional_certs.len(),
283 "Reloading TLS certificates"
284 );
285
286 let new_resolver = SniResolver::from_config(&config)?;
288
289 *self.inner.write() = Arc::new(new_resolver);
291 *self.last_reload.write() = Instant::now();
292
293 info!("TLS certificates reloaded successfully");
294 Ok(())
295 }
296
297 pub fn update_config(&self, new_config: TlsConfig) -> Result<(), TlsError> {
299 let new_resolver = SniResolver::from_config(&new_config)?;
301
302 *self.config.write() = new_config;
304 *self.inner.write() = Arc::new(new_resolver);
305 *self.last_reload.write() = Instant::now();
306
307 info!("TLS configuration updated and certificates reloaded");
308 Ok(())
309 }
310
311 pub fn last_reload_age(&self) -> Duration {
313 self.last_reload.read().elapsed()
314 }
315
316 pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
320 self.inner.read().resolve(server_name)
321 }
322}
323
324impl ResolvesServerCert for HotReloadableSniResolver {
325 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
326 Some(self.inner.read().resolve(client_hello.server_name()))
327 }
328}
329
330pub struct CertificateReloader {
334 resolvers: RwLock<HashMap<String, Arc<HotReloadableSniResolver>>>,
336}
337
338impl CertificateReloader {
339 pub fn new() -> Self {
341 Self {
342 resolvers: RwLock::new(HashMap::new()),
343 }
344 }
345
346 pub fn register(&self, listener_id: &str, resolver: Arc<HotReloadableSniResolver>) {
348 debug!(listener_id = %listener_id, "Registering TLS resolver for hot-reload");
349 self.resolvers
350 .write()
351 .insert(listener_id.to_string(), resolver);
352 }
353
354 pub fn reload_all(&self) -> (usize, Vec<(String, TlsError)>) {
358 let resolvers = self.resolvers.read();
359 let mut success_count = 0;
360 let mut errors = Vec::new();
361
362 info!(
363 listener_count = resolvers.len(),
364 "Reloading certificates for all TLS listeners"
365 );
366
367 for (listener_id, resolver) in resolvers.iter() {
368 match resolver.reload() {
369 Ok(()) => {
370 success_count += 1;
371 debug!(listener_id = %listener_id, "Certificate reload successful");
372 }
373 Err(e) => {
374 error!(listener_id = %listener_id, error = %e, "Certificate reload failed");
375 errors.push((listener_id.clone(), e));
376 }
377 }
378 }
379
380 if errors.is_empty() {
381 info!(
382 success_count = success_count,
383 "All certificates reloaded successfully"
384 );
385 } else {
386 warn!(
387 success_count = success_count,
388 error_count = errors.len(),
389 "Certificate reload completed with errors"
390 );
391 }
392
393 (success_count, errors)
394 }
395
396 pub fn status(&self) -> HashMap<String, Duration> {
398 self.resolvers
399 .read()
400 .iter()
401 .map(|(id, resolver)| (id.clone(), resolver.last_reload_age()))
402 .collect()
403 }
404}
405
406impl Default for CertificateReloader {
407 fn default() -> Self {
408 Self::new()
409 }
410}
411
412#[derive(Debug, Clone)]
418pub struct OcspCacheEntry {
419 pub response: Vec<u8>,
421 pub fetched_at: Instant,
423 pub expires_at: Option<Instant>,
425}
426
427pub struct OcspStapler {
431 cache: RwLock<HashMap<String, OcspCacheEntry>>,
433 refresh_interval: Duration,
435}
436
437impl OcspStapler {
438 pub fn new() -> Self {
440 Self {
441 cache: RwLock::new(HashMap::new()),
442 refresh_interval: Duration::from_secs(3600), }
444 }
445
446 pub fn with_refresh_interval(interval: Duration) -> Self {
448 Self {
449 cache: RwLock::new(HashMap::new()),
450 refresh_interval: interval,
451 }
452 }
453
454 pub fn get_response(&self, cert_fingerprint: &str) -> Option<Vec<u8>> {
456 let cache = self.cache.read();
457 if let Some(entry) = cache.get(cert_fingerprint) {
458 if entry.fetched_at.elapsed() < self.refresh_interval {
460 trace!(fingerprint = %cert_fingerprint, "OCSP cache hit");
461 return Some(entry.response.clone());
462 }
463 trace!(fingerprint = %cert_fingerprint, "OCSP cache expired");
464 }
465 None
466 }
467
468 pub fn fetch_ocsp_response(
473 &self,
474 cert_der: &[u8],
475 issuer_der: &[u8],
476 ) -> Result<Vec<u8>, TlsError> {
477 use x509_parser::prelude::*;
478
479 let (_, cert) = X509Certificate::from_der(cert_der)
481 .map_err(|e| TlsError::OcspFetch(format!("Failed to parse certificate: {}", e)))?;
482
483 let (_, issuer) = X509Certificate::from_der(issuer_der)
485 .map_err(|e| TlsError::OcspFetch(format!("Failed to parse issuer certificate: {}", e)))?;
486
487 let ocsp_url = extract_ocsp_responder_url(&cert)?;
489 debug!(url = %ocsp_url, "Found OCSP responder URL");
490
491 let ocsp_request = build_ocsp_request(&cert, &issuer)?;
493
494 let response = send_ocsp_request_sync(&ocsp_url, &ocsp_request)?;
497
498 let fingerprint = calculate_cert_fingerprint(cert_der);
500
501 let entry = OcspCacheEntry {
503 response: response.clone(),
504 fetched_at: Instant::now(),
505 expires_at: None, };
507 self.cache.write().insert(fingerprint, entry);
508
509 info!("Successfully fetched and cached OCSP response");
510 Ok(response)
511 }
512
513 pub async fn fetch_ocsp_response_async(
515 &self,
516 cert_der: &[u8],
517 issuer_der: &[u8],
518 ) -> Result<Vec<u8>, TlsError> {
519 use x509_parser::prelude::*;
520
521 let (_, cert) = X509Certificate::from_der(cert_der)
523 .map_err(|e| TlsError::OcspFetch(format!("Failed to parse certificate: {}", e)))?;
524
525 let (_, issuer) = X509Certificate::from_der(issuer_der)
527 .map_err(|e| TlsError::OcspFetch(format!("Failed to parse issuer certificate: {}", e)))?;
528
529 let ocsp_url = extract_ocsp_responder_url(&cert)?;
531 debug!(url = %ocsp_url, "Found OCSP responder URL");
532
533 let ocsp_request = build_ocsp_request(&cert, &issuer)?;
535
536 let response = send_ocsp_request_async(&ocsp_url, &ocsp_request).await?;
538
539 let fingerprint = calculate_cert_fingerprint(cert_der);
541
542 let entry = OcspCacheEntry {
544 response: response.clone(),
545 fetched_at: Instant::now(),
546 expires_at: None,
547 };
548 self.cache.write().insert(fingerprint, entry);
549
550 info!("Successfully fetched and cached OCSP response (async)");
551 Ok(response)
552 }
553
554 pub fn prefetch_for_config(&self, config: &TlsConfig) -> Vec<String> {
556 let mut warnings = Vec::new();
557
558 if !config.ocsp_stapling {
559 trace!("OCSP stapling disabled in config");
560 return warnings;
561 }
562
563 info!("Prefetching OCSP responses for certificates");
564
565 warnings.push("OCSP stapling prefetch not yet fully implemented".to_string());
568
569 warnings
570 }
571
572 pub fn clear_cache(&self) {
574 self.cache.write().clear();
575 info!("OCSP cache cleared");
576 }
577}
578
579impl Default for OcspStapler {
580 fn default() -> Self {
581 Self::new()
582 }
583}
584
585fn extract_ocsp_responder_url(cert: &x509_parser::certificate::X509Certificate) -> Result<String, TlsError> {
591 use x509_parser::prelude::*;
592
593 let aia = cert
595 .extensions()
596 .iter()
597 .find(|ext| ext.oid == oid_registry::OID_PKIX_AUTHORITY_INFO_ACCESS)
598 .ok_or_else(|| TlsError::OcspFetch(
599 "Certificate does not have Authority Information Access extension".to_string()
600 ))?;
601
602 let aia_value = match aia.parsed_extension() {
604 ParsedExtension::AuthorityInfoAccess(aia) => aia,
605 _ => return Err(TlsError::OcspFetch(
606 "Failed to parse Authority Information Access extension".to_string()
607 )),
608 };
609
610 for access in &aia_value.accessdescs {
612 if access.access_method == oid_registry::OID_PKIX_ACCESS_DESCRIPTOR_OCSP {
613 match &access.access_location {
614 GeneralName::URI(url) => {
615 return Ok(url.to_string());
616 }
617 _ => continue,
618 }
619 }
620 }
621
622 Err(TlsError::OcspFetch(
623 "Certificate AIA does not contain OCSP responder URL".to_string()
624 ))
625}
626
627fn build_ocsp_request(
631 cert: &x509_parser::certificate::X509Certificate,
632 issuer: &x509_parser::certificate::X509Certificate,
633) -> Result<Vec<u8>, TlsError> {
634 use sha2::{Sha256, Digest};
635
636 let issuer_name_hash = {
643 let mut hasher = Sha256::new();
644 hasher.update(issuer.subject().as_raw());
645 hasher.finalize()
646 };
647
648 let issuer_key_hash = {
650 let mut hasher = Sha256::new();
651 hasher.update(issuer.public_key().subject_public_key.data.as_ref());
652 hasher.finalize()
653 };
654
655 let serial = cert.serial.to_bytes_be();
657
658 let request = build_ocsp_request_der(
661 &issuer_name_hash,
662 &issuer_key_hash,
663 &serial,
664 );
665
666 Ok(request)
667}
668
/// Encodes an unsigned OCSP request (RFC 6960 section 4.1.1) for a single certificate.
///
/// Structure: `OCSPRequest { TBSRequest { requestList [ Request { CertID } ] } }`.
/// `CertID.hashAlgorithm` is id-sha256 with NULL parameters; the two hashes are inserted as
/// given, and `serial_number` is a big-endian unsigned magnitude.
fn build_ocsp_request_der(
    issuer_name_hash: &[u8],
    issuer_key_hash: &[u8],
    serial_number: &[u8],
) -> Vec<u8> {
    // id-sha256 = 2.16.840.1.101.3.4.2.1 (OID content octets only, without tag/length).
    const SHA256_OID: &[u8] = &[0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01];

    let hash_algorithm = der_sequence(&[&der_oid(SHA256_OID), &der_null()]);

    let cert_id = der_sequence(&[
        &hash_algorithm,
        &der_octet_string(issuer_name_hash),
        &der_octet_string(issuer_key_hash),
        &der_integer(serial_number),
    ]);

    let request = der_sequence(&[&cert_id]);
    let request_list = der_sequence(&[&request]);
    let tbs_request = der_sequence(&[&request_list]);

    der_sequence(&[&tbs_request])
}

/// Encodes a DER tag-length-value triple.
fn der_tlv(tag: u8, content: &[u8]) -> Vec<u8> {
    let length = der_length(content.len());
    let mut out = Vec::with_capacity(1 + length.len() + content.len());
    out.push(tag);
    out.extend_from_slice(&length);
    out.extend_from_slice(content);
    out
}

/// Encodes a SEQUENCE whose content is the concatenation of already-encoded `items`.
fn der_sequence(items: &[&[u8]]) -> Vec<u8> {
    der_tlv(0x30, &items.concat())
}

/// Encodes an OBJECT IDENTIFIER from its pre-encoded content octets.
fn der_oid(oid: &[u8]) -> Vec<u8> {
    der_tlv(0x06, oid)
}

/// Encodes ASN.1 NULL.
fn der_null() -> Vec<u8> {
    vec![0x05, 0x00]
}

/// Encodes an OCTET STRING.
fn der_octet_string(data: &[u8]) -> Vec<u8> {
    der_tlv(0x04, data)
}

/// Encodes a non-negative INTEGER from a big-endian unsigned magnitude.
///
/// Leading zero bytes are stripped (empty or all-zero input encodes 0), and a 0x00 byte is
/// prepended when the top bit is set so the value is not read as negative.
fn der_integer(data: &[u8]) -> Vec<u8> {
    let magnitude = match data.iter().position(|&b| b != 0) {
        Some(pos) => &data[pos..],
        None => &[0u8][..],
    };
    let mut content = Vec::with_capacity(magnitude.len() + 1);
    if magnitude[0] & 0x80 != 0 {
        content.push(0x00);
    }
    content.extend_from_slice(magnitude);
    der_tlv(0x02, &content)
}

/// Encodes a DER definite length for any `len`.
///
/// Uses the short form below 128; otherwise the long form `0x80 | n` followed by the `n`
/// minimal big-endian length octets. The previous fixed two-octet form silently truncated
/// lengths of 65536 and above.
fn der_length(len: usize) -> Vec<u8> {
    if len < 0x80 {
        return vec![len as u8];
    }
    let bytes = len.to_be_bytes();
    let first = bytes
        .iter()
        .position(|&b| b != 0)
        .unwrap_or(bytes.len() - 1);
    let significant = &bytes[first..];
    let mut out = Vec::with_capacity(1 + significant.len());
    // `significant.len()` is at most size_of::<usize>() (<= 16), so it fits in 7 bits.
    out.push(0x80 | significant.len() as u8);
    out.extend_from_slice(significant);
    out
}
761
762fn send_ocsp_request_sync(url: &str, request: &[u8]) -> Result<Vec<u8>, TlsError> {
764 use std::io::{Read, Write};
765 use std::net::TcpStream;
766 use std::time::Duration;
767
768 let url = url::Url::parse(url)
770 .map_err(|e| TlsError::OcspFetch(format!("Invalid OCSP URL: {}", e)))?;
771
772 let host = url.host_str()
773 .ok_or_else(|| TlsError::OcspFetch("OCSP URL has no host".to_string()))?;
774 let port = url.port().unwrap_or(80);
775 let path = if url.path().is_empty() { "/" } else { url.path() };
776
777 let addr = format!("{}:{}", host, port);
779 let mut stream = TcpStream::connect(&addr)
780 .map_err(|e| TlsError::OcspFetch(format!("Failed to connect to OCSP responder: {}", e)))?;
781
782 stream.set_read_timeout(Some(Duration::from_secs(10)))
783 .map_err(|e| TlsError::OcspFetch(format!("Failed to set timeout: {}", e)))?;
784 stream.set_write_timeout(Some(Duration::from_secs(10)))
785 .map_err(|e| TlsError::OcspFetch(format!("Failed to set timeout: {}", e)))?;
786
787 let http_request = format!(
789 "POST {} HTTP/1.1\r\n\
790 Host: {}\r\n\
791 Content-Type: application/ocsp-request\r\n\
792 Content-Length: {}\r\n\
793 Connection: close\r\n\
794 \r\n",
795 path, host, request.len()
796 );
797
798 stream.write_all(http_request.as_bytes())
800 .map_err(|e| TlsError::OcspFetch(format!("Failed to send OCSP request: {}", e)))?;
801 stream.write_all(request)
802 .map_err(|e| TlsError::OcspFetch(format!("Failed to send OCSP request body: {}", e)))?;
803
804 let mut response = Vec::new();
806 stream.read_to_end(&mut response)
807 .map_err(|e| TlsError::OcspFetch(format!("Failed to read OCSP response: {}", e)))?;
808
809 let headers_end = response.windows(4)
811 .position(|w| w == b"\r\n\r\n")
812 .ok_or_else(|| TlsError::OcspFetch("Invalid HTTP response: no headers end".to_string()))?;
813
814 let body = &response[headers_end + 4..];
815 if body.is_empty() {
816 return Err(TlsError::OcspFetch("Empty OCSP response body".to_string()));
817 }
818
819 Ok(body.to_vec())
820}
821
822async fn send_ocsp_request_async(url: &str, request: &[u8]) -> Result<Vec<u8>, TlsError> {
824 let client = reqwest::Client::builder()
825 .timeout(Duration::from_secs(10))
826 .build()
827 .map_err(|e| TlsError::OcspFetch(format!("Failed to create HTTP client: {}", e)))?;
828
829 let response = client
830 .post(url)
831 .header("Content-Type", "application/ocsp-request")
832 .body(request.to_vec())
833 .send()
834 .await
835 .map_err(|e| TlsError::OcspFetch(format!("OCSP request failed: {}", e)))?;
836
837 if !response.status().is_success() {
838 return Err(TlsError::OcspFetch(format!(
839 "OCSP responder returned status: {}",
840 response.status()
841 )));
842 }
843
844 let body = response.bytes().await
845 .map_err(|e| TlsError::OcspFetch(format!("Failed to read OCSP response: {}", e)))?;
846
847 Ok(body.to_vec())
848}
849
850fn calculate_cert_fingerprint(cert_der: &[u8]) -> String {
852 use sha2::{Sha256, Digest};
853 let mut hasher = Sha256::new();
854 hasher.update(cert_der);
855 let result = hasher.finalize();
856 hex::encode(result)
857}
858
859pub fn load_client_cert_key(
877 cert_path: &Path,
878 key_path: &Path,
879) -> Result<Arc<pingora_core::utils::tls::CertKey>, TlsError> {
880 let cert_file = File::open(cert_path)
882 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
883 let mut cert_reader = BufReader::new(cert_file);
884
885 let cert_ders: Vec<Vec<u8>> = rustls_pemfile::certs(&mut cert_reader)
887 .collect::<Result<Vec<_>, _>>()
888 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?
889 .into_iter()
890 .map(|c| c.to_vec())
891 .collect();
892
893 if cert_ders.is_empty() {
894 return Err(TlsError::CertificateLoad(format!(
895 "{}: No certificates found in PEM file",
896 cert_path.display()
897 )));
898 }
899
900 let key_file = File::open(key_path)
902 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
903 let mut key_reader = BufReader::new(key_file);
904
905 let key_der = rustls_pemfile::private_key(&mut key_reader)
907 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
908 .ok_or_else(|| {
909 TlsError::KeyLoad(format!(
910 "{}: No private key found in PEM file",
911 key_path.display()
912 ))
913 })?
914 .secret_der()
915 .to_vec();
916
917 let cert_key = pingora_core::utils::tls::CertKey::new(cert_ders, key_der);
919
920 debug!(
921 cert_path = %cert_path.display(),
922 key_path = %key_path.display(),
923 "Loaded mTLS client certificate for upstream connections"
924 );
925
926 Ok(Arc::new(cert_key))
927}
928
929pub fn build_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<ClientConfig, TlsError> {
934 let mut root_store = RootCertStore::empty();
935
936 if let Some(ca_path) = &config.ca_cert {
938 let ca_file = File::open(ca_path)
939 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
940 let mut ca_reader = BufReader::new(ca_file);
941
942 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
943 .collect::<Result<Vec<_>, _>>()
944 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
945
946 for cert in certs {
947 root_store.add(cert).map_err(|e| {
948 TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
949 })?;
950 }
951
952 debug!(
953 ca_file = %ca_path.display(),
954 cert_count = root_store.len(),
955 "Loaded upstream CA certificates"
956 );
957 } else if !config.insecure_skip_verify {
958 root_store = RootCertStore {
960 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
961 };
962 trace!("Using webpki-roots for upstream TLS verification");
963 }
964
965 let builder = ClientConfig::builder().with_root_certificates(root_store);
967
968 let client_config = if let (Some(cert_path), Some(key_path)) =
969 (&config.client_cert, &config.client_key)
970 {
971 let cert_file = File::open(cert_path)
973 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
974 let mut cert_reader = BufReader::new(cert_file);
975
976 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
977 .collect::<Result<Vec<_>, _>>()
978 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
979
980 if certs.is_empty() {
981 return Err(TlsError::CertificateLoad(format!(
982 "{}: No certificates found",
983 cert_path.display()
984 )));
985 }
986
987 let key_file = File::open(key_path)
989 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
990 let mut key_reader = BufReader::new(key_file);
991
992 let key = rustls_pemfile::private_key(&mut key_reader)
993 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
994 .ok_or_else(|| {
995 TlsError::KeyLoad(format!("{}: No private key found", key_path.display()))
996 })?;
997
998 info!(
999 cert_file = %cert_path.display(),
1000 "Configured mTLS client certificate for upstream connections"
1001 );
1002
1003 builder
1004 .with_client_auth_cert(certs, key)
1005 .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to set client auth: {}", e)))?
1006 } else {
1007 builder.with_no_client_auth()
1009 };
1010
1011 debug!("Upstream TLS configuration built successfully");
1012 Ok(client_config)
1013}
1014
1015pub fn validate_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<(), TlsError> {
1017 if let Some(ca_path) = &config.ca_cert {
1019 if !ca_path.exists() {
1020 return Err(TlsError::CertificateLoad(format!(
1021 "Upstream CA certificate not found: {}",
1022 ca_path.display()
1023 )));
1024 }
1025 }
1026
1027 if let Some(cert_path) = &config.client_cert {
1029 if !cert_path.exists() {
1030 return Err(TlsError::CertificateLoad(format!(
1031 "Upstream client certificate not found: {}",
1032 cert_path.display()
1033 )));
1034 }
1035
1036 match &config.client_key {
1038 Some(key_path) if !key_path.exists() => {
1039 return Err(TlsError::KeyLoad(format!(
1040 "Upstream client key not found: {}",
1041 key_path.display()
1042 )));
1043 }
1044 None => {
1045 return Err(TlsError::ConfigBuild(
1046 "client_cert specified without client_key".to_string(),
1047 ));
1048 }
1049 _ => {}
1050 }
1051 }
1052
1053 if config.client_key.is_some() && config.client_cert.is_none() {
1054 return Err(TlsError::ConfigBuild(
1055 "client_key specified without client_cert".to_string(),
1056 ));
1057 }
1058
1059 Ok(())
1060}
1061
1062fn load_certified_key(cert_path: &Path, key_path: &Path) -> Result<CertifiedKey, TlsError> {
1068 let cert_file = File::open(cert_path)
1070 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1071 let mut cert_reader = BufReader::new(cert_file);
1072
1073 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
1074 .collect::<Result<Vec<_>, _>>()
1075 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1076
1077 if certs.is_empty() {
1078 return Err(TlsError::CertificateLoad(format!(
1079 "{}: No certificates found in file",
1080 cert_path.display()
1081 )));
1082 }
1083
1084 let key_file = File::open(key_path)
1086 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
1087 let mut key_reader = BufReader::new(key_file);
1088
1089 let key = rustls_pemfile::private_key(&mut key_reader)
1090 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
1091 .ok_or_else(|| {
1092 TlsError::KeyLoad(format!(
1093 "{}: No private key found in file",
1094 key_path.display()
1095 ))
1096 })?;
1097
1098 let provider = rustls::crypto::CryptoProvider::get_default()
1100 .cloned()
1101 .unwrap_or_else(|| Arc::new(rustls::crypto::aws_lc_rs::default_provider()));
1102
1103 let signing_key = provider
1104 .key_provider
1105 .load_private_key(key)
1106 .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to load private key: {:?}", e)))?;
1107
1108 Ok(CertifiedKey::new(certs, signing_key))
1109}
1110
1111pub fn load_client_ca(ca_path: &Path) -> Result<RootCertStore, TlsError> {
1113 let ca_file = File::open(ca_path)
1114 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
1115 let mut ca_reader = BufReader::new(ca_file);
1116
1117 let mut root_store = RootCertStore::empty();
1118
1119 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
1120 .collect::<Result<Vec<_>, _>>()
1121 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
1122
1123 for cert in certs {
1124 root_store.add(cert).map_err(|e| {
1125 TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
1126 })?;
1127 }
1128
1129 if root_store.is_empty() {
1130 return Err(TlsError::CertificateLoad(format!(
1131 "{}: No CA certificates found",
1132 ca_path.display()
1133 )));
1134 }
1135
1136 info!(
1137 ca_file = %ca_path.display(),
1138 cert_count = root_store.len(),
1139 "Loaded client CA certificates"
1140 );
1141
1142 Ok(root_store)
1143}
1144
1145pub fn build_server_config(config: &TlsConfig) -> Result<ServerConfig, TlsError> {
1147 let resolver = SniResolver::from_config(config)?;
1148
1149 let builder = ServerConfig::builder();
1150
1151 let server_config = if config.client_auth {
1153 if let Some(ca_path) = &config.ca_file {
1154 let root_store = load_client_ca(ca_path)?;
1155 let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
1156 .build()
1157 .map_err(|e| {
1158 TlsError::ConfigBuild(format!("Failed to build client verifier: {}", e))
1159 })?;
1160
1161 info!("mTLS enabled: client certificates required");
1162
1163 builder
1164 .with_client_cert_verifier(verifier)
1165 .with_cert_resolver(Arc::new(resolver))
1166 } else {
1167 warn!("client_auth enabled but no ca_file specified, disabling client auth");
1168 builder
1169 .with_no_client_auth()
1170 .with_cert_resolver(Arc::new(resolver))
1171 }
1172 } else {
1173 builder
1174 .with_no_client_auth()
1175 .with_cert_resolver(Arc::new(resolver))
1176 };
1177
1178 let mut config = server_config;
1180 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
1181
1182 debug!("TLS configuration built successfully");
1183
1184 Ok(config)
1185}
1186
1187pub fn validate_tls_config(config: &TlsConfig) -> Result<(), TlsError> {
1189 if config.acme.is_some() {
1191 trace!("Skipping manual cert validation for ACME-managed TLS");
1193 } else {
1194 match (&config.cert_file, &config.key_file) {
1196 (Some(cert_file), Some(key_file)) => {
1197 if !cert_file.exists() {
1198 return Err(TlsError::CertificateLoad(format!(
1199 "Certificate file not found: {}",
1200 cert_file.display()
1201 )));
1202 }
1203 if !key_file.exists() {
1204 return Err(TlsError::KeyLoad(format!(
1205 "Key file not found: {}",
1206 key_file.display()
1207 )));
1208 }
1209 }
1210 _ => {
1211 return Err(TlsError::ConfigBuild(
1212 "TLS configuration requires cert_file and key_file (or ACME block)".to_string(),
1213 ));
1214 }
1215 }
1216 }
1217
1218 for sni in &config.additional_certs {
1220 if !sni.cert_file.exists() {
1221 return Err(TlsError::CertificateLoad(format!(
1222 "SNI certificate file not found: {}",
1223 sni.cert_file.display()
1224 )));
1225 }
1226 if !sni.key_file.exists() {
1227 return Err(TlsError::KeyLoad(format!(
1228 "SNI key file not found: {}",
1229 sni.key_file.display()
1230 )));
1231 }
1232 }
1233
1234 if config.client_auth {
1236 if let Some(ca_path) = &config.ca_file {
1237 if !ca_path.exists() {
1238 return Err(TlsError::CertificateLoad(format!(
1239 "CA certificate file not found: {}",
1240 ca_path.display()
1241 )));
1242 }
1243 }
1244 }
1245
1246 Ok(())
1247}
1248
#[cfg(test)]
mod tests {
    /// Dropping leading labels from a dotted hostname yields its successive parent domains.
    #[test]
    fn test_wildcard_matching() {
        let labels: Vec<&str> = "foo.bar.example.com".split('.').collect();
        assert_eq!(labels.len(), 4);

        let parents: Vec<String> = (1..labels.len())
            .map(|skip| labels[skip..].join("."))
            .collect();
        assert_eq!(parents[0], "bar.example.com");
        assert_eq!(parents[1], "example.com");
    }

    /// Hostnames are compared in lowercase.
    #[test]
    fn test_hostname_normalization() {
        assert_eq!("Example.COM".to_lowercase(), "example.com");
    }
}