1use std::collections::HashMap;
49use std::fs::File;
50use std::io::BufReader;
51use std::path::Path;
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54
55use parking_lot::RwLock;
56use rustls::client::ClientConfig;
57use rustls::pki_types::CertificateDer;
58use rustls::server::{ClientHello, ResolvesServerCert};
59use rustls::sign::CertifiedKey;
60use rustls::{RootCertStore, ServerConfig};
61use tracing::{debug, error, info, trace, warn};
62
63use grapsus_config::{TlsConfig, UpstreamTlsConfig};
64
/// Errors produced while loading certificates and keys, building TLS
/// configuration, or fetching OCSP responses.
///
/// Every variant carries a human-readable detail string; the `Display`
/// implementation prepends a fixed, variant-specific prefix to it.
#[derive(Debug)]
pub enum TlsError {
    /// A certificate (or CA bundle) file could not be opened, parsed, or was empty.
    CertificateLoad(String),
    /// A private key file could not be opened, parsed, or contained no key.
    KeyLoad(String),
    /// The TLS configuration itself is incomplete or could not be assembled.
    ConfigBuild(String),
    /// The private key could not be loaded by the crypto provider or was
    /// rejected when paired with a certificate chain.
    CertKeyMismatch(String),
    /// A certificate was rejected when added to a root store.
    InvalidCertificate(String),
    /// Any failure while locating, contacting, or reading from an OCSP responder.
    OcspFetch(String),
}
81
82impl std::fmt::Display for TlsError {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 match self {
85 TlsError::CertificateLoad(e) => write!(f, "Failed to load certificate: {}", e),
86 TlsError::KeyLoad(e) => write!(f, "Failed to load private key: {}", e),
87 TlsError::ConfigBuild(e) => write!(f, "Failed to build TLS config: {}", e),
88 TlsError::CertKeyMismatch(e) => write!(f, "Certificate/key mismatch: {}", e),
89 TlsError::InvalidCertificate(e) => write!(f, "Invalid certificate: {}", e),
90 TlsError::OcspFetch(e) => write!(f, "Failed to fetch OCSP response: {}", e),
91 }
92 }
93}
94
// The default `Error` methods suffice: every variant carries only a `String`,
// so there is no underlying source error to expose.
impl std::error::Error for TlsError {}
96
/// Selects a server certificate based on the SNI hostname sent by the client.
///
/// Lookup order is: exact hostname, then wildcard domain, then the default
/// certificate.
#[derive(Debug)]
pub struct SniResolver {
    /// Certificate returned when the client sends no SNI or nothing matches.
    default_cert: Arc<CertifiedKey>,
    /// Exact-match certificates, keyed by lowercased hostname.
    sni_certs: HashMap<String, Arc<CertifiedKey>>,
    /// Wildcard certificates, keyed by the lowercased domain with the leading
    /// `*.` removed (e.g. `*.example.com` is stored under `example.com`).
    wildcard_certs: HashMap<String, Arc<CertifiedKey>>,
}
114
115impl SniResolver {
116 pub fn from_config(config: &TlsConfig) -> Result<Self, TlsError> {
118 let (cert_path_buf, key_path_buf);
120 let (cert_file, key_file) = match (&config.cert_file, &config.key_file) {
121 (Some(cert), Some(key)) => (cert.as_path(), key.as_path()),
122 _ if config.acme.is_some() => {
123 let acme = config.acme.as_ref().unwrap();
124 let primary = acme.domains.first().ok_or_else(|| {
125 TlsError::ConfigBuild(
126 "ACME configuration has no domains for cert path resolution".to_string(),
127 )
128 })?;
129 cert_path_buf = acme.storage.join("domains").join(primary).join("cert.pem");
130 key_path_buf = acme.storage.join("domains").join(primary).join("key.pem");
131 (cert_path_buf.as_path(), key_path_buf.as_path())
132 }
133 _ => {
134 return Err(TlsError::ConfigBuild(
135 "TLS configuration requires cert_file and key_file (or ACME block)".to_string(),
136 ));
137 }
138 };
139
140 let default_cert = load_certified_key(cert_file, key_file)?;
142
143 info!(
144 cert_file = %cert_file.display(),
145 "Loaded default TLS certificate"
146 );
147
148 let mut sni_certs = HashMap::new();
149 let mut wildcard_certs = HashMap::new();
150
151 for sni_config in &config.additional_certs {
153 let cert = load_certified_key(&sni_config.cert_file, &sni_config.key_file)?;
154 let cert = Arc::new(cert);
155
156 for hostname in &sni_config.hostnames {
157 let hostname_lower = hostname.to_lowercase();
158
159 if hostname_lower.starts_with("*.") {
160 let domain = hostname_lower.strip_prefix("*.").unwrap().to_string();
162 wildcard_certs.insert(domain.clone(), cert.clone());
163 debug!(
164 pattern = %hostname,
165 domain = %domain,
166 cert_file = %sni_config.cert_file.display(),
167 "Registered wildcard SNI certificate"
168 );
169 } else {
170 sni_certs.insert(hostname_lower.clone(), cert.clone());
172 debug!(
173 hostname = %hostname_lower,
174 cert_file = %sni_config.cert_file.display(),
175 "Registered SNI certificate"
176 );
177 }
178 }
179 }
180
181 info!(
182 exact_certs = sni_certs.len(),
183 wildcard_certs = wildcard_certs.len(),
184 "SNI resolver initialized"
185 );
186
187 Ok(Self {
188 default_cert: Arc::new(default_cert),
189 sni_certs,
190 wildcard_certs,
191 })
192 }
193
194 pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
199 let Some(name) = server_name else {
200 debug!("No SNI provided, using default certificate");
201 return self.default_cert.clone();
202 };
203
204 let name_lower = name.to_lowercase();
205
206 if let Some(cert) = self.sni_certs.get(&name_lower) {
208 debug!(hostname = %name_lower, "SNI exact match found");
209 return cert.clone();
210 }
211
212 let parts: Vec<&str> = name_lower.split('.').collect();
215 for i in 1..parts.len() {
216 let domain = parts[i..].join(".");
217 if let Some(cert) = self.wildcard_certs.get(&domain) {
218 debug!(
219 hostname = %name_lower,
220 wildcard_domain = %domain,
221 "SNI wildcard match found"
222 );
223 return cert.clone();
224 }
225 }
226
227 debug!(
228 hostname = %name_lower,
229 "No SNI match found, using default certificate"
230 );
231 self.default_cert.clone()
232 }
233}
234
235impl ResolvesServerCert for SniResolver {
236 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
237 Some(self.resolve(client_hello.server_name()))
238 }
239}
240
/// An [`SniResolver`] that can be rebuilt and swapped at runtime.
pub struct HotReloadableSniResolver {
    /// The resolver currently used for lookups; replaced on reload.
    inner: RwLock<Arc<SniResolver>>,
    /// The configuration from which `inner` is (re)built.
    config: RwLock<TlsConfig>,
    /// When `inner` was last built or replaced.
    last_reload: RwLock<Instant>,
}
258
259impl std::fmt::Debug for HotReloadableSniResolver {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 f.debug_struct("HotReloadableSniResolver")
262 .field("last_reload", &*self.last_reload.read())
263 .finish()
264 }
265}
266
267impl HotReloadableSniResolver {
268 pub fn from_config(config: TlsConfig) -> Result<Self, TlsError> {
270 let resolver = SniResolver::from_config(&config)?;
271
272 Ok(Self {
273 inner: RwLock::new(Arc::new(resolver)),
274 config: RwLock::new(config),
275 last_reload: RwLock::new(Instant::now()),
276 })
277 }
278
279 pub fn reload(&self) -> Result<(), TlsError> {
284 let config = self.config.read();
285
286 let cert_file_display = config
287 .cert_file
288 .as_ref()
289 .map(|p| p.display().to_string())
290 .unwrap_or_else(|| "(acme-managed)".to_string());
291
292 info!(
293 cert_file = %cert_file_display,
294 sni_count = config.additional_certs.len(),
295 "Reloading TLS certificates"
296 );
297
298 let new_resolver = SniResolver::from_config(&config)?;
300
301 *self.inner.write() = Arc::new(new_resolver);
303 *self.last_reload.write() = Instant::now();
304
305 info!("TLS certificates reloaded successfully");
306 Ok(())
307 }
308
309 pub fn update_config(&self, new_config: TlsConfig) -> Result<(), TlsError> {
311 let new_resolver = SniResolver::from_config(&new_config)?;
313
314 *self.config.write() = new_config;
316 *self.inner.write() = Arc::new(new_resolver);
317 *self.last_reload.write() = Instant::now();
318
319 info!("TLS configuration updated and certificates reloaded");
320 Ok(())
321 }
322
323 pub fn last_reload_age(&self) -> Duration {
325 self.last_reload.read().elapsed()
326 }
327
328 pub fn resolve(&self, server_name: Option<&str>) -> Arc<CertifiedKey> {
332 self.inner.read().resolve(server_name)
333 }
334}
335
336impl ResolvesServerCert for HotReloadableSniResolver {
337 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
338 Some(self.inner.read().resolve(client_hello.server_name()))
339 }
340}
341
/// Registry of hot-reloadable resolvers, allowing all of them to be reloaded
/// in one call.
pub struct CertificateReloader {
    /// Registered resolvers, keyed by listener id.
    resolvers: RwLock<HashMap<String, Arc<HotReloadableSniResolver>>>,
}
349
350impl CertificateReloader {
351 pub fn new() -> Self {
353 Self {
354 resolvers: RwLock::new(HashMap::new()),
355 }
356 }
357
358 pub fn register(&self, listener_id: &str, resolver: Arc<HotReloadableSniResolver>) {
360 debug!(listener_id = %listener_id, "Registering TLS resolver for hot-reload");
361 self.resolvers
362 .write()
363 .insert(listener_id.to_string(), resolver);
364 }
365
366 pub fn reload_all(&self) -> (usize, Vec<(String, TlsError)>) {
370 let resolvers = self.resolvers.read();
371 let mut success_count = 0;
372 let mut errors = Vec::new();
373
374 info!(
375 listener_count = resolvers.len(),
376 "Reloading certificates for all TLS listeners"
377 );
378
379 for (listener_id, resolver) in resolvers.iter() {
380 match resolver.reload() {
381 Ok(()) => {
382 success_count += 1;
383 debug!(listener_id = %listener_id, "Certificate reload successful");
384 }
385 Err(e) => {
386 error!(listener_id = %listener_id, error = %e, "Certificate reload failed");
387 errors.push((listener_id.clone(), e));
388 }
389 }
390 }
391
392 if errors.is_empty() {
393 info!(
394 success_count = success_count,
395 "All certificates reloaded successfully"
396 );
397 } else {
398 warn!(
399 success_count = success_count,
400 error_count = errors.len(),
401 "Certificate reload completed with errors"
402 );
403 }
404
405 (success_count, errors)
406 }
407
408 pub fn status(&self) -> HashMap<String, Duration> {
410 self.resolvers
411 .read()
412 .iter()
413 .map(|(id, resolver)| (id.clone(), resolver.last_reload_age()))
414 .collect()
415 }
416}
417
impl Default for CertificateReloader {
    /// Equivalent to [`CertificateReloader::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
423
/// A cached OCSP response together with its bookkeeping timestamps.
#[derive(Debug, Clone)]
pub struct OcspCacheEntry {
    /// The response body exactly as received from the OCSP responder.
    pub response: Vec<u8>,
    /// When the response was fetched and stored.
    pub fetched_at: Instant,
    /// Optional expiry instant. The fetch methods in this file always store
    /// `None` here.
    pub expires_at: Option<Instant>,
}
438
/// Fetches OCSP responses and caches them in memory.
pub struct OcspStapler {
    /// Cached responses, keyed by the hex-encoded SHA-256 of the certificate DER.
    cache: RwLock<HashMap<String, OcspCacheEntry>>,
    /// Maximum age of a cached entry before it is no longer served.
    refresh_interval: Duration,
}
448
449impl OcspStapler {
450 pub fn new() -> Self {
452 Self {
453 cache: RwLock::new(HashMap::new()),
454 refresh_interval: Duration::from_secs(3600), }
456 }
457
458 pub fn with_refresh_interval(interval: Duration) -> Self {
460 Self {
461 cache: RwLock::new(HashMap::new()),
462 refresh_interval: interval,
463 }
464 }
465
466 pub fn get_response(&self, cert_fingerprint: &str) -> Option<Vec<u8>> {
468 let cache = self.cache.read();
469 if let Some(entry) = cache.get(cert_fingerprint) {
470 if entry.fetched_at.elapsed() < self.refresh_interval {
472 trace!(fingerprint = %cert_fingerprint, "OCSP cache hit");
473 return Some(entry.response.clone());
474 }
475 trace!(fingerprint = %cert_fingerprint, "OCSP cache expired");
476 }
477 None
478 }
479
480 pub fn fetch_ocsp_response(
485 &self,
486 cert_der: &[u8],
487 issuer_der: &[u8],
488 ) -> Result<Vec<u8>, TlsError> {
489 use x509_parser::prelude::*;
490
491 let (_, cert) = X509Certificate::from_der(cert_der)
493 .map_err(|e| TlsError::OcspFetch(format!("Failed to parse certificate: {}", e)))?;
494
495 let (_, issuer) = X509Certificate::from_der(issuer_der).map_err(|e| {
497 TlsError::OcspFetch(format!("Failed to parse issuer certificate: {}", e))
498 })?;
499
500 let ocsp_url = extract_ocsp_responder_url(&cert)?;
502 debug!(url = %ocsp_url, "Found OCSP responder URL");
503
504 let ocsp_request = build_ocsp_request(&cert, &issuer)?;
506
507 let response = send_ocsp_request_sync(&ocsp_url, &ocsp_request)?;
510
511 let fingerprint = calculate_cert_fingerprint(cert_der);
513
514 let entry = OcspCacheEntry {
516 response: response.clone(),
517 fetched_at: Instant::now(),
518 expires_at: None, };
520 self.cache.write().insert(fingerprint, entry);
521
522 info!("Successfully fetched and cached OCSP response");
523 Ok(response)
524 }
525
526 pub async fn fetch_ocsp_response_async(
528 &self,
529 cert_der: &[u8],
530 issuer_der: &[u8],
531 ) -> Result<Vec<u8>, TlsError> {
532 use x509_parser::prelude::*;
533
534 let (_, cert) = X509Certificate::from_der(cert_der)
536 .map_err(|e| TlsError::OcspFetch(format!("Failed to parse certificate: {}", e)))?;
537
538 let (_, issuer) = X509Certificate::from_der(issuer_der).map_err(|e| {
540 TlsError::OcspFetch(format!("Failed to parse issuer certificate: {}", e))
541 })?;
542
543 let ocsp_url = extract_ocsp_responder_url(&cert)?;
545 debug!(url = %ocsp_url, "Found OCSP responder URL");
546
547 let ocsp_request = build_ocsp_request(&cert, &issuer)?;
549
550 let response = send_ocsp_request_async(&ocsp_url, &ocsp_request).await?;
552
553 let fingerprint = calculate_cert_fingerprint(cert_der);
555
556 let entry = OcspCacheEntry {
558 response: response.clone(),
559 fetched_at: Instant::now(),
560 expires_at: None,
561 };
562 self.cache.write().insert(fingerprint, entry);
563
564 info!("Successfully fetched and cached OCSP response (async)");
565 Ok(response)
566 }
567
568 pub fn prefetch_for_config(&self, config: &TlsConfig) -> Vec<String> {
570 let mut warnings = Vec::new();
571
572 if !config.ocsp_stapling {
573 trace!("OCSP stapling disabled in config");
574 return warnings;
575 }
576
577 info!("Prefetching OCSP responses for certificates");
578
579 warnings.push("OCSP stapling prefetch not yet fully implemented".to_string());
582
583 warnings
584 }
585
586 pub fn clear_cache(&self) {
588 self.cache.write().clear();
589 info!("OCSP cache cleared");
590 }
591}
592
impl Default for OcspStapler {
    /// Equivalent to [`OcspStapler::new`].
    fn default() -> Self {
        Self::new()
    }
}
598
599fn extract_ocsp_responder_url(
605 cert: &x509_parser::certificate::X509Certificate,
606) -> Result<String, TlsError> {
607 use x509_parser::prelude::*;
608
609 let aia = cert
611 .extensions()
612 .iter()
613 .find(|ext| ext.oid == oid_registry::OID_PKIX_AUTHORITY_INFO_ACCESS)
614 .ok_or_else(|| {
615 TlsError::OcspFetch(
616 "Certificate does not have Authority Information Access extension".to_string(),
617 )
618 })?;
619
620 let aia_value = match aia.parsed_extension() {
622 ParsedExtension::AuthorityInfoAccess(aia) => aia,
623 _ => {
624 return Err(TlsError::OcspFetch(
625 "Failed to parse Authority Information Access extension".to_string(),
626 ))
627 }
628 };
629
630 for access in &aia_value.accessdescs {
632 if access.access_method == oid_registry::OID_PKIX_ACCESS_DESCRIPTOR_OCSP {
633 match &access.access_location {
634 GeneralName::URI(url) => {
635 return Ok(url.to_string());
636 }
637 _ => continue,
638 }
639 }
640 }
641
642 Err(TlsError::OcspFetch(
643 "Certificate AIA does not contain OCSP responder URL".to_string(),
644 ))
645}
646
647fn build_ocsp_request(
651 cert: &x509_parser::certificate::X509Certificate,
652 issuer: &x509_parser::certificate::X509Certificate,
653) -> Result<Vec<u8>, TlsError> {
654 use sha2::{Digest, Sha256};
655
656 let issuer_name_hash = {
663 let mut hasher = Sha256::new();
664 hasher.update(issuer.subject().as_raw());
665 hasher.finalize()
666 };
667
668 let issuer_key_hash = {
670 let mut hasher = Sha256::new();
671 hasher.update(issuer.public_key().subject_public_key.data.as_ref());
672 hasher.finalize()
673 };
674
675 let serial = cert.serial.to_bytes_be();
677
678 let request = build_ocsp_request_der(&issuer_name_hash, &issuer_key_hash, &serial);
681
682 Ok(request)
683}
684
685fn build_ocsp_request_der(
687 issuer_name_hash: &[u8],
688 issuer_key_hash: &[u8],
689 serial_number: &[u8],
690) -> Vec<u8> {
691 let sha256_oid: &[u8] = &[0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01];
693
694 let hash_algorithm = der_sequence(&[&der_oid(sha256_oid), &der_null()]);
696
697 let cert_id = der_sequence(&[
698 &hash_algorithm,
699 &der_octet_string(issuer_name_hash),
700 &der_octet_string(issuer_key_hash),
701 &der_integer(serial_number),
702 ]);
703
704 let request = der_sequence(&[&cert_id]);
706
707 let request_list = der_sequence(&[&request]);
709
710 let tbs_request = der_sequence(&[&request_list]);
712
713 der_sequence(&[&tbs_request])
715}
716
717fn der_sequence(items: &[&[u8]]) -> Vec<u8> {
719 let mut content = Vec::new();
720 for item in items {
721 content.extend_from_slice(item);
722 }
723 let mut result = vec![0x30]; result.extend(der_length(content.len()));
725 result.extend(content);
726 result
727}
728
729fn der_oid(oid: &[u8]) -> Vec<u8> {
730 let mut result = vec![0x06]; result.extend(der_length(oid.len()));
732 result.extend_from_slice(oid);
733 result
734}
735
/// Encodes a DER NULL: tag `0x05` with zero length.
fn der_null() -> Vec<u8> {
    [0x05, 0x00].to_vec()
}
739
740fn der_octet_string(data: &[u8]) -> Vec<u8> {
741 let mut result = vec![0x04]; result.extend(der_length(data.len()));
743 result.extend_from_slice(data);
744 result
745}
746
747fn der_integer(data: &[u8]) -> Vec<u8> {
748 let mut result = vec![0x02]; let data = match data.iter().position(|&b| b != 0) {
751 Some(pos) => &data[pos..],
752 None => &[0],
753 };
754 if !data.is_empty() && data[0] & 0x80 != 0 {
756 result.extend(der_length(data.len() + 1));
757 result.push(0x00);
758 } else {
759 result.extend(der_length(data.len()));
760 }
761 result.extend_from_slice(data);
762 result
763}
764
/// Encodes `len` as a DER definite length.
///
/// Lengths below 128 use the single-byte short form. Larger lengths use the
/// long form: one byte `0x80 | n` followed by the `n` big-endian bytes of the
/// length, with no leading zero bytes. Every `usize` value is encoded exactly;
/// nothing is truncated.
fn der_length(len: usize) -> Vec<u8> {
    if len < 0x80 {
        return vec![len as u8];
    }

    let bytes = len.to_be_bytes();
    let skip = bytes.iter().take_while(|&&byte| byte == 0).count();
    let significant = &bytes[skip..];

    let mut encoded = Vec::with_capacity(1 + significant.len());
    // At most `size_of::<usize>()` bytes follow, so the count fits in 7 bits.
    encoded.push(0x80 | significant.len() as u8);
    encoded.extend_from_slice(significant);
    encoded
}
774
775fn send_ocsp_request_sync(url: &str, request: &[u8]) -> Result<Vec<u8>, TlsError> {
777 use std::io::{Read, Write};
778 use std::net::TcpStream;
779 use std::time::Duration;
780
781 let url = url::Url::parse(url)
783 .map_err(|e| TlsError::OcspFetch(format!("Invalid OCSP URL: {}", e)))?;
784
785 let host = url
786 .host_str()
787 .ok_or_else(|| TlsError::OcspFetch("OCSP URL has no host".to_string()))?;
788 let port = url.port().unwrap_or(80);
789 let path = if url.path().is_empty() {
790 "/"
791 } else {
792 url.path()
793 };
794
795 let addr = format!("{}:{}", host, port);
797 let mut stream = TcpStream::connect(&addr)
798 .map_err(|e| TlsError::OcspFetch(format!("Failed to connect to OCSP responder: {}", e)))?;
799
800 stream
801 .set_read_timeout(Some(Duration::from_secs(10)))
802 .map_err(|e| TlsError::OcspFetch(format!("Failed to set timeout: {}", e)))?;
803 stream
804 .set_write_timeout(Some(Duration::from_secs(10)))
805 .map_err(|e| TlsError::OcspFetch(format!("Failed to set timeout: {}", e)))?;
806
807 let http_request = format!(
809 "POST {} HTTP/1.1\r\n\
810 Host: {}\r\n\
811 Content-Type: application/ocsp-request\r\n\
812 Content-Length: {}\r\n\
813 Connection: close\r\n\
814 \r\n",
815 path,
816 host,
817 request.len()
818 );
819
820 stream
822 .write_all(http_request.as_bytes())
823 .map_err(|e| TlsError::OcspFetch(format!("Failed to send OCSP request: {}", e)))?;
824 stream
825 .write_all(request)
826 .map_err(|e| TlsError::OcspFetch(format!("Failed to send OCSP request body: {}", e)))?;
827
828 let mut response = Vec::new();
830 stream
831 .read_to_end(&mut response)
832 .map_err(|e| TlsError::OcspFetch(format!("Failed to read OCSP response: {}", e)))?;
833
834 let headers_end = response
836 .windows(4)
837 .position(|w| w == b"\r\n\r\n")
838 .ok_or_else(|| TlsError::OcspFetch("Invalid HTTP response: no headers end".to_string()))?;
839
840 let body = &response[headers_end + 4..];
841 if body.is_empty() {
842 return Err(TlsError::OcspFetch("Empty OCSP response body".to_string()));
843 }
844
845 Ok(body.to_vec())
846}
847
848async fn send_ocsp_request_async(url: &str, request: &[u8]) -> Result<Vec<u8>, TlsError> {
850 let client = reqwest::Client::builder()
851 .timeout(Duration::from_secs(10))
852 .build()
853 .map_err(|e| TlsError::OcspFetch(format!("Failed to create HTTP client: {}", e)))?;
854
855 let response = client
856 .post(url)
857 .header("Content-Type", "application/ocsp-request")
858 .body(request.to_vec())
859 .send()
860 .await
861 .map_err(|e| TlsError::OcspFetch(format!("OCSP request failed: {}", e)))?;
862
863 if !response.status().is_success() {
864 return Err(TlsError::OcspFetch(format!(
865 "OCSP responder returned status: {}",
866 response.status()
867 )));
868 }
869
870 let body = response
871 .bytes()
872 .await
873 .map_err(|e| TlsError::OcspFetch(format!("Failed to read OCSP response: {}", e)))?;
874
875 Ok(body.to_vec())
876}
877
878fn calculate_cert_fingerprint(cert_der: &[u8]) -> String {
880 use sha2::{Digest, Sha256};
881 let mut hasher = Sha256::new();
882 hasher.update(cert_der);
883 let result = hasher.finalize();
884 hex::encode(result)
885}
886
887pub fn load_client_cert_key(
905 cert_path: &Path,
906 key_path: &Path,
907) -> Result<Arc<pingora_core::utils::tls::CertKey>, TlsError> {
908 let cert_file = File::open(cert_path)
910 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
911 let mut cert_reader = BufReader::new(cert_file);
912
913 let cert_ders: Vec<Vec<u8>> = rustls_pemfile::certs(&mut cert_reader)
915 .collect::<Result<Vec<_>, _>>()
916 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?
917 .into_iter()
918 .map(|c| c.to_vec())
919 .collect();
920
921 if cert_ders.is_empty() {
922 return Err(TlsError::CertificateLoad(format!(
923 "{}: No certificates found in PEM file",
924 cert_path.display()
925 )));
926 }
927
928 let key_file = File::open(key_path)
930 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
931 let mut key_reader = BufReader::new(key_file);
932
933 let key_der = rustls_pemfile::private_key(&mut key_reader)
935 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
936 .ok_or_else(|| {
937 TlsError::KeyLoad(format!(
938 "{}: No private key found in PEM file",
939 key_path.display()
940 ))
941 })?
942 .secret_der()
943 .to_vec();
944
945 let cert_key = pingora_core::utils::tls::CertKey::new(cert_ders, key_der);
947
948 debug!(
949 cert_path = %cert_path.display(),
950 key_path = %key_path.display(),
951 "Loaded mTLS client certificate for upstream connections"
952 );
953
954 Ok(Arc::new(cert_key))
955}
956
957pub fn build_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<ClientConfig, TlsError> {
962 let mut root_store = RootCertStore::empty();
963
964 if let Some(ca_path) = &config.ca_cert {
966 let ca_file = File::open(ca_path)
967 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
968 let mut ca_reader = BufReader::new(ca_file);
969
970 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
971 .collect::<Result<Vec<_>, _>>()
972 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
973
974 for cert in certs {
975 root_store.add(cert).map_err(|e| {
976 TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
977 })?;
978 }
979
980 debug!(
981 ca_file = %ca_path.display(),
982 cert_count = root_store.len(),
983 "Loaded upstream CA certificates"
984 );
985 } else if !config.insecure_skip_verify {
986 root_store = RootCertStore {
988 roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
989 };
990 trace!("Using webpki-roots for upstream TLS verification");
991 }
992
993 let builder = ClientConfig::builder().with_root_certificates(root_store);
995
996 let client_config = if let (Some(cert_path), Some(key_path)) =
997 (&config.client_cert, &config.client_key)
998 {
999 let cert_file = File::open(cert_path)
1001 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1002 let mut cert_reader = BufReader::new(cert_file);
1003
1004 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
1005 .collect::<Result<Vec<_>, _>>()
1006 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1007
1008 if certs.is_empty() {
1009 return Err(TlsError::CertificateLoad(format!(
1010 "{}: No certificates found",
1011 cert_path.display()
1012 )));
1013 }
1014
1015 let key_file = File::open(key_path)
1017 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
1018 let mut key_reader = BufReader::new(key_file);
1019
1020 let key = rustls_pemfile::private_key(&mut key_reader)
1021 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
1022 .ok_or_else(|| {
1023 TlsError::KeyLoad(format!("{}: No private key found", key_path.display()))
1024 })?;
1025
1026 info!(
1027 cert_file = %cert_path.display(),
1028 "Configured mTLS client certificate for upstream connections"
1029 );
1030
1031 builder
1032 .with_client_auth_cert(certs, key)
1033 .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to set client auth: {}", e)))?
1034 } else {
1035 builder.with_no_client_auth()
1037 };
1038
1039 debug!("Upstream TLS configuration built successfully");
1040 Ok(client_config)
1041}
1042
1043pub fn validate_upstream_tls_config(config: &UpstreamTlsConfig) -> Result<(), TlsError> {
1045 if let Some(ca_path) = &config.ca_cert {
1047 if !ca_path.exists() {
1048 return Err(TlsError::CertificateLoad(format!(
1049 "Upstream CA certificate not found: {}",
1050 ca_path.display()
1051 )));
1052 }
1053 }
1054
1055 if let Some(cert_path) = &config.client_cert {
1057 if !cert_path.exists() {
1058 return Err(TlsError::CertificateLoad(format!(
1059 "Upstream client certificate not found: {}",
1060 cert_path.display()
1061 )));
1062 }
1063
1064 match &config.client_key {
1066 Some(key_path) if !key_path.exists() => {
1067 return Err(TlsError::KeyLoad(format!(
1068 "Upstream client key not found: {}",
1069 key_path.display()
1070 )));
1071 }
1072 None => {
1073 return Err(TlsError::ConfigBuild(
1074 "client_cert specified without client_key".to_string(),
1075 ));
1076 }
1077 _ => {}
1078 }
1079 }
1080
1081 if config.client_key.is_some() && config.client_cert.is_none() {
1082 return Err(TlsError::ConfigBuild(
1083 "client_key specified without client_cert".to_string(),
1084 ));
1085 }
1086
1087 Ok(())
1088}
1089
1090fn load_certified_key(cert_path: &Path, key_path: &Path) -> Result<CertifiedKey, TlsError> {
1096 let cert_file = File::open(cert_path)
1098 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1099 let mut cert_reader = BufReader::new(cert_file);
1100
1101 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
1102 .collect::<Result<Vec<_>, _>>()
1103 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", cert_path.display(), e)))?;
1104
1105 if certs.is_empty() {
1106 return Err(TlsError::CertificateLoad(format!(
1107 "{}: No certificates found in file",
1108 cert_path.display()
1109 )));
1110 }
1111
1112 let key_file = File::open(key_path)
1114 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?;
1115 let mut key_reader = BufReader::new(key_file);
1116
1117 let key = rustls_pemfile::private_key(&mut key_reader)
1118 .map_err(|e| TlsError::KeyLoad(format!("{}: {}", key_path.display(), e)))?
1119 .ok_or_else(|| {
1120 TlsError::KeyLoad(format!(
1121 "{}: No private key found in file",
1122 key_path.display()
1123 ))
1124 })?;
1125
1126 let provider = rustls::crypto::CryptoProvider::get_default()
1128 .cloned()
1129 .unwrap_or_else(|| Arc::new(rustls::crypto::aws_lc_rs::default_provider()));
1130
1131 let signing_key = provider
1132 .key_provider
1133 .load_private_key(key)
1134 .map_err(|e| TlsError::CertKeyMismatch(format!("Failed to load private key: {:?}", e)))?;
1135
1136 Ok(CertifiedKey::new(certs, signing_key))
1137}
1138
1139pub fn load_client_ca(ca_path: &Path) -> Result<RootCertStore, TlsError> {
1141 let ca_file = File::open(ca_path)
1142 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
1143 let mut ca_reader = BufReader::new(ca_file);
1144
1145 let mut root_store = RootCertStore::empty();
1146
1147 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
1148 .collect::<Result<Vec<_>, _>>()
1149 .map_err(|e| TlsError::CertificateLoad(format!("{}: {}", ca_path.display(), e)))?;
1150
1151 for cert in certs {
1152 root_store.add(cert).map_err(|e| {
1153 TlsError::InvalidCertificate(format!("Failed to add CA certificate: {}", e))
1154 })?;
1155 }
1156
1157 if root_store.is_empty() {
1158 return Err(TlsError::CertificateLoad(format!(
1159 "{}: No CA certificates found",
1160 ca_path.display()
1161 )));
1162 }
1163
1164 info!(
1165 ca_file = %ca_path.display(),
1166 cert_count = root_store.len(),
1167 "Loaded client CA certificates"
1168 );
1169
1170 Ok(root_store)
1171}
1172
1173fn resolve_protocol_versions(config: &TlsConfig) -> Vec<&'static rustls::SupportedProtocolVersion> {
1175 use grapsus_common::types::TlsVersion;
1176
1177 let min = &config.min_version;
1178 let max = config.max_version.as_ref().unwrap_or(&TlsVersion::Tls13);
1179
1180 let mut versions = Vec::new();
1181
1182 if matches!(min, TlsVersion::Tls12) {
1184 versions.push(&rustls::version::TLS12);
1185 }
1186
1187 if matches!(max, TlsVersion::Tls13) {
1189 versions.push(&rustls::version::TLS13);
1190 }
1191
1192 if versions.is_empty() {
1193 warn!("No valid TLS versions resolved from config, falling back to TLS 1.2 + 1.3");
1195 versions.push(&rustls::version::TLS12);
1196 versions.push(&rustls::version::TLS13);
1197 }
1198
1199 versions
1200}
1201
1202fn resolve_cipher_suites(names: &[String]) -> Result<Vec<rustls::SupportedCipherSuite>, TlsError> {
1206 use rustls::crypto::aws_lc_rs::cipher_suite;
1207
1208 let known: &[(&str, rustls::SupportedCipherSuite)] = &[
1210 (
1212 "TLS_AES_256_GCM_SHA384",
1213 cipher_suite::TLS13_AES_256_GCM_SHA384,
1214 ),
1215 (
1216 "TLS_AES_128_GCM_SHA256",
1217 cipher_suite::TLS13_AES_128_GCM_SHA256,
1218 ),
1219 (
1220 "TLS_CHACHA20_POLY1305_SHA256",
1221 cipher_suite::TLS13_CHACHA20_POLY1305_SHA256,
1222 ),
1223 (
1225 "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
1226 cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
1227 ),
1228 (
1229 "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
1230 cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
1231 ),
1232 (
1233 "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
1234 cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
1235 ),
1236 (
1237 "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
1238 cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
1239 ),
1240 (
1241 "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
1242 cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
1243 ),
1244 (
1245 "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
1246 cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
1247 ),
1248 ];
1249
1250 let mut suites = Vec::with_capacity(names.len());
1251 for name in names {
1252 let normalized = name.to_uppercase().replace('-', "_");
1253 match known.iter().find(|(n, _)| *n == normalized) {
1254 Some((_, suite)) => suites.push(*suite),
1255 None => {
1256 let available: Vec<&str> = known.iter().map(|(n, _)| *n).collect();
1257 return Err(TlsError::ConfigBuild(format!(
1258 "Unknown cipher suite '{}'. Available: {}",
1259 name,
1260 available.join(", ")
1261 )));
1262 }
1263 }
1264 }
1265
1266 Ok(suites)
1267}
1268
1269pub fn build_server_config(config: &TlsConfig) -> Result<ServerConfig, TlsError> {
1283 let resolver = SniResolver::from_config(config)?;
1284
1285 let versions = resolve_protocol_versions(config);
1287 info!(
1288 versions = ?versions.iter().map(|v| format!("{:?}", v.version)).collect::<Vec<_>>(),
1289 "TLS protocol versions configured"
1290 );
1291
1292 let builder = if !config.cipher_suites.is_empty() {
1294 let suites = resolve_cipher_suites(&config.cipher_suites)?;
1295 info!(
1296 cipher_suites = ?config.cipher_suites,
1297 count = suites.len(),
1298 "Custom TLS cipher suites configured"
1299 );
1300 let provider = rustls::crypto::CryptoProvider {
1301 cipher_suites: suites,
1302 ..rustls::crypto::aws_lc_rs::default_provider()
1303 };
1304 ServerConfig::builder_with_provider(Arc::new(provider))
1305 .with_protocol_versions(&versions)
1306 .map_err(|e| {
1307 TlsError::ConfigBuild(format!("Invalid TLS protocol/cipher configuration: {}", e))
1308 })?
1309 } else {
1310 ServerConfig::builder_with_protocol_versions(&versions)
1311 };
1312
1313 let server_config = if config.client_auth {
1315 if let Some(ca_path) = &config.ca_file {
1316 let root_store = load_client_ca(ca_path)?;
1317 let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
1318 .build()
1319 .map_err(|e| {
1320 TlsError::ConfigBuild(format!("Failed to build client verifier: {}", e))
1321 })?;
1322
1323 info!("mTLS enabled: client certificates required");
1324
1325 builder
1326 .with_client_cert_verifier(verifier)
1327 .with_cert_resolver(Arc::new(resolver))
1328 } else {
1329 warn!("client_auth enabled but no ca_file specified, disabling client auth");
1330 builder
1331 .with_no_client_auth()
1332 .with_cert_resolver(Arc::new(resolver))
1333 }
1334 } else {
1335 builder
1336 .with_no_client_auth()
1337 .with_cert_resolver(Arc::new(resolver))
1338 };
1339
1340 let mut server_config = server_config;
1342 server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
1343
1344 if !config.session_resumption {
1346 server_config.session_storage = Arc::new(rustls::server::NoServerSessionStorage {});
1347 info!("TLS session resumption disabled");
1348 }
1349
1350 debug!("TLS configuration built successfully");
1351
1352 Ok(server_config)
1353}
1354
1355pub fn validate_tls_config(config: &TlsConfig) -> Result<(), TlsError> {
1357 if config.acme.is_some() {
1359 trace!("Skipping manual cert validation for ACME-managed TLS");
1361 } else {
1362 match (&config.cert_file, &config.key_file) {
1364 (Some(cert_file), Some(key_file)) => {
1365 if !cert_file.exists() {
1366 return Err(TlsError::CertificateLoad(format!(
1367 "Certificate file not found: {}",
1368 cert_file.display()
1369 )));
1370 }
1371 if !key_file.exists() {
1372 return Err(TlsError::KeyLoad(format!(
1373 "Key file not found: {}",
1374 key_file.display()
1375 )));
1376 }
1377 }
1378 _ => {
1379 return Err(TlsError::ConfigBuild(
1380 "TLS configuration requires cert_file and key_file (or ACME block)".to_string(),
1381 ));
1382 }
1383 }
1384 }
1385
1386 for sni in &config.additional_certs {
1388 if !sni.cert_file.exists() {
1389 return Err(TlsError::CertificateLoad(format!(
1390 "SNI certificate file not found: {}",
1391 sni.cert_file.display()
1392 )));
1393 }
1394 if !sni.key_file.exists() {
1395 return Err(TlsError::KeyLoad(format!(
1396 "SNI key file not found: {}",
1397 sni.key_file.display()
1398 )));
1399 }
1400 }
1401
1402 if config.client_auth {
1404 if let Some(ca_path) = &config.ca_file {
1405 if !ca_path.exists() {
1406 return Err(TlsError::CertificateLoad(format!(
1407 "CA certificate file not found: {}",
1408 ca_path.display()
1409 )));
1410 }
1411 }
1412 }
1413
1414 Ok(())
1415}
1416
#[cfg(test)]
mod tests {

    /// Illustrates the suffix derivation used for wildcard lookup: dropping
    /// leading labels yields progressively shorter parent domains.
    #[test]
    fn test_wildcard_matching() {
        let labels: Vec<&str> = "foo.bar.example.com".split('.').collect();

        assert_eq!(labels.len(), 4);
        assert_eq!(labels[1..].join("."), "bar.example.com");
        assert_eq!(labels[2..].join("."), "example.com");
    }

    /// Hostnames are compared in lowercase.
    #[test]
    fn test_hostname_normalization() {
        assert_eq!("Example.COM".to_lowercase(), "example.com");
    }
}