1use std::collections::HashMap;
34use std::fs;
35use std::io::BufReader;
36use std::path::Path;
37use std::sync::Arc;
38use std::time::{Duration, SystemTime};
39
40use parking_lot::RwLock;
41use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
42use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
43use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
44use rustls::{
45 ClientConfig, DigitallySignedStruct, DistinguishedName, RootCertStore, ServerConfig,
46 SignatureScheme,
47};
48use tokio_rustls::{TlsAcceptor, TlsConnector};
49use tracing::{debug, error, info, warn};
50use x509_parser::prelude::*;
51
52use crate::error::{NetError, NetResult};
53use crate::tls::{CertificateInfo, CertificateLoader, CertificateStore, HotReloadableCertificates};
54
55#[derive(Debug, Clone, PartialEq, Eq)]
57pub struct Principal {
58 pub name: String,
60 pub organization: Option<String>,
62 pub organizational_unit: Option<String>,
64 pub email: Option<String>,
66 pub serial: String,
68 pub fingerprint: String,
70 pub attributes: HashMap<String, String>,
72}
73
74impl Principal {
75 pub fn from_certificate(cert: &CertificateDer<'_>) -> NetResult<Self> {
77 let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
78 NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
79 })?;
80
81 let name = parsed
82 .subject()
83 .iter_common_name()
84 .next()
85 .and_then(|cn| cn.as_str().ok())
86 .map(String::from)
87 .unwrap_or_else(|| "unknown".to_string());
88
89 let organization = parsed
90 .subject()
91 .iter_organization()
92 .next()
93 .and_then(|o| o.as_str().ok())
94 .map(String::from);
95
96 let organizational_unit = parsed
97 .subject()
98 .iter_organizational_unit()
99 .next()
100 .and_then(|ou| ou.as_str().ok())
101 .map(String::from);
102
103 let mut email = None;
104 if let Ok(Some(san)) = parsed.subject_alternative_name() {
105 for name in san.value.general_names.iter() {
106 if let GeneralName::RFC822Name(e) = name {
107 email = Some(e.to_string());
108 break;
109 }
110 }
111 }
112
113 let serial = format!("{:x}", parsed.serial);
114 use std::fmt::Write;
116 let fingerprint = cert
117 .as_ref()
118 .iter()
119 .take(32)
120 .fold(String::new(), |mut s, b| {
121 let _ = write!(&mut s, "{b:02x}");
122 s
123 });
124
125 Ok(Self {
126 name,
127 organization,
128 organizational_unit,
129 email,
130 serial,
131 fingerprint,
132 attributes: HashMap::new(),
133 })
134 }
135
136 pub fn with_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
138 self.attributes.insert(key.into(), value.into());
139 self
140 }
141}
142
143pub trait PrincipalMapper: Send + Sync {
145 fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal>;
147
148 fn get_principal_name(&self, principal: &Principal) -> String;
150}
151
152#[derive(Debug, Clone, Default)]
154pub struct DefaultPrincipalMapper;
155
156impl PrincipalMapper for DefaultPrincipalMapper {
157 fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
158 Principal::from_certificate(cert)
159 }
160
161 fn get_principal_name(&self, principal: &Principal) -> String {
162 principal.name.clone()
163 }
164}
165
166#[derive(Debug, Clone, Default)]
168pub struct OrganizationPrincipalMapper;
169
170impl PrincipalMapper for OrganizationPrincipalMapper {
171 fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
172 Principal::from_certificate(cert)
173 }
174
175 fn get_principal_name(&self, principal: &Principal) -> String {
176 match &principal.organization {
177 Some(org) => format!("{}/{}", org, principal.name),
178 None => principal.name.clone(),
179 }
180 }
181}
182
183#[derive(Debug, Clone, Copy, PartialEq, Eq)]
185pub enum RevocationStatus {
186 Good,
188 Revoked,
190 Unknown,
192 CheckFailed,
194}
195
196pub trait RevocationChecker: Send + Sync {
198 fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus>;
200
201 fn check_revocation_async(
203 &self,
204 cert: &CertificateDer<'_>,
205 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>;
206}
207
208#[derive(Debug)]
210pub struct CrlRevocationChecker {
211 revoked_serials: Arc<RwLock<HashMap<String, SystemTime>>>,
213 last_update: Arc<RwLock<Option<SystemTime>>>,
215 crl_url: Option<String>,
217}
218
219impl Default for CrlRevocationChecker {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
225impl CrlRevocationChecker {
226 pub fn new() -> Self {
228 Self {
229 revoked_serials: Arc::new(RwLock::new(HashMap::new())),
230 last_update: Arc::new(RwLock::new(None)),
231 crl_url: None,
232 }
233 }
234
235 pub fn with_crl_url(mut self, url: impl Into<String>) -> Self {
237 self.crl_url = Some(url.into());
238 self
239 }
240
241 pub fn load_crl_der<P: AsRef<Path>>(&self, path: P) -> NetResult<usize> {
243 let data = fs::read(path.as_ref())
244 .map_err(|e| NetError::InvalidCertificate(format!("Failed to read CRL file: {e}")))?;
245
246 self.load_crl_bytes(&data)
247 }
248
249 pub fn load_crl_pem<P: AsRef<Path>>(&self, path: P) -> NetResult<usize> {
251 let file = fs::File::open(path.as_ref())
252 .map_err(|e| NetError::InvalidCertificate(format!("Failed to open CRL file: {e}")))?;
253 let mut reader = BufReader::new(file);
254
255 let crls: Vec<_> = rustls_pemfile::crls(&mut reader)
256 .filter_map(|r| r.ok())
257 .collect();
258
259 if crls.is_empty() {
260 return Err(NetError::InvalidCertificate(
261 "No CRLs found in PEM file".to_string(),
262 ));
263 }
264
265 let mut total = 0;
266 for crl in crls {
267 total += self.load_crl_bytes(crl.as_ref())?;
268 }
269
270 Ok(total)
271 }
272
273 pub fn load_crl_bytes(&self, crl_data: &[u8]) -> NetResult<usize> {
275 let (_, crl) = CertificateRevocationList::from_der(crl_data)
276 .map_err(|e| NetError::InvalidCertificate(format!("Failed to parse CRL: {e}")))?;
277
278 let mut revoked = self.revoked_serials.write();
279 let mut count = 0;
280
281 for entry in crl.iter_revoked_certificates() {
282 let serial = format!("{:x}", entry.user_certificate);
283 let revocation_time = SystemTime::UNIX_EPOCH; revoked.insert(serial, revocation_time);
285 count += 1;
286 }
287
288 {
289 let mut last = self.last_update.write();
290 *last = Some(SystemTime::now());
291 }
292
293 info!(count = count, "Loaded CRL entries");
294 Ok(count)
295 }
296
297 pub fn add_revoked(&self, serial: impl Into<String>) {
299 let mut revoked = self.revoked_serials.write();
300 revoked.insert(serial.into(), SystemTime::now());
301 }
302
303 pub fn is_revoked(&self, serial: &str) -> bool {
305 self.revoked_serials.read().contains_key(serial)
306 }
307
308 pub fn get_revocation_time(&self, serial: &str) -> Option<SystemTime> {
310 self.revoked_serials.read().get(serial).copied()
311 }
312
313 pub fn revoked_count(&self) -> usize {
315 self.revoked_serials.read().len()
316 }
317
318 pub fn clear(&self) {
320 self.revoked_serials.write().clear();
321 *self.last_update.write() = None;
322 }
323}
324
325impl RevocationChecker for CrlRevocationChecker {
326 fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus> {
327 let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
328 NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
329 })?;
330
331 let serial = format!("{:x}", parsed.serial);
332
333 if self.is_revoked(&serial) {
334 Ok(RevocationStatus::Revoked)
335 } else if self.last_update.read().is_some() {
336 Ok(RevocationStatus::Good)
337 } else {
338 Ok(RevocationStatus::Unknown)
339 }
340 }
341
342 fn check_revocation_async(
343 &self,
344 cert: &CertificateDer<'_>,
345 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>
346 {
347 let result = self.check_revocation(cert);
348 Box::pin(async move { result })
349 }
350}
351
352#[derive(Debug, Default)]
354pub struct OcspRevocationChecker {
355 responder_url: Option<String>,
357 response_cache: Arc<RwLock<HashMap<String, (RevocationStatus, SystemTime)>>>,
359 cache_ttl: Duration,
361}
362
363impl OcspRevocationChecker {
364 pub fn new() -> Self {
366 Self {
367 responder_url: None,
368 response_cache: Arc::new(RwLock::new(HashMap::new())),
369 cache_ttl: Duration::from_secs(3600), }
371 }
372
373 pub fn with_responder_url(mut self, url: impl Into<String>) -> Self {
375 self.responder_url = Some(url.into());
376 self
377 }
378
379 pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
381 self.cache_ttl = ttl;
382 self
383 }
384
385 fn get_cached(&self, fingerprint: &str) -> Option<RevocationStatus> {
387 let cache = self.response_cache.read();
388 if let Some((status, timestamp)) = cache.get(fingerprint) {
389 if timestamp.elapsed().unwrap_or(Duration::MAX) < self.cache_ttl {
390 return Some(*status);
391 }
392 }
393 None
394 }
395
396 fn cache_status(&self, fingerprint: String, status: RevocationStatus) {
398 let mut cache = self.response_cache.write();
399 cache.insert(fingerprint, (status, SystemTime::now()));
400 }
401}
402
403impl RevocationChecker for OcspRevocationChecker {
404 fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus> {
405 use std::fmt::Write;
407 let fingerprint = cert
408 .as_ref()
409 .iter()
410 .take(32)
411 .fold(String::new(), |mut s, b| {
412 let _ = write!(&mut s, "{b:02x}");
413 s
414 });
415
416 if let Some(status) = self.get_cached(&fingerprint) {
418 return Ok(status);
419 }
420
421 warn!("OCSP checking requires async network request, returning Unknown");
424 Ok(RevocationStatus::Unknown)
425 }
426
427 fn check_revocation_async(
428 &self,
429 cert: &CertificateDer<'_>,
430 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>
431 {
432 use std::fmt::Write;
434 let fingerprint = cert
435 .as_ref()
436 .iter()
437 .take(32)
438 .fold(String::new(), |mut s, b| {
439 let _ = write!(&mut s, "{b:02x}");
440 s
441 });
442
443 if let Some(status) = self.get_cached(&fingerprint) {
445 return Box::pin(async move { Ok(status) });
446 }
447
448 let cache_fn = {
451 let fingerprint_clone = fingerprint.clone();
452 let checker = self;
453 move |status: RevocationStatus| {
454 checker.cache_status(fingerprint_clone, status);
455 }
456 };
457
458 Box::pin(async move {
459 warn!("OCSP async check not fully implemented, returning Unknown");
466 let status = RevocationStatus::Unknown;
467 cache_fn(status);
468 Ok(status)
469 })
470 }
471}
472
473#[derive(Debug)]
475pub struct CombinedRevocationChecker {
476 crl: Arc<CrlRevocationChecker>,
478 ocsp: Arc<OcspRevocationChecker>,
480 prefer_ocsp: bool,
482}
483
484impl CombinedRevocationChecker {
485 pub fn new(crl: Arc<CrlRevocationChecker>, ocsp: Arc<OcspRevocationChecker>) -> Self {
487 Self {
488 crl,
489 ocsp,
490 prefer_ocsp: false,
491 }
492 }
493
494 pub fn prefer_ocsp(mut self) -> Self {
496 self.prefer_ocsp = true;
497 self
498 }
499}
500
501impl RevocationChecker for CombinedRevocationChecker {
502 fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus> {
503 if self.prefer_ocsp {
504 match self.ocsp.check_revocation(cert)? {
506 RevocationStatus::Unknown | RevocationStatus::CheckFailed => {
507 self.crl.check_revocation(cert)
509 }
510 status => Ok(status),
511 }
512 } else {
513 match self.crl.check_revocation(cert)? {
515 RevocationStatus::Unknown | RevocationStatus::CheckFailed => {
516 self.ocsp.check_revocation(cert)
518 }
519 status => Ok(status),
520 }
521 }
522 }
523
524 fn check_revocation_async(
525 &self,
526 cert: &CertificateDer<'_>,
527 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>
528 {
529 let result = self.check_revocation(cert);
530 Box::pin(async move { result })
531 }
532}
533
534pub struct MtlsClientVerifier {
536 roots: Arc<RootCertStore>,
538 mapper: Arc<dyn PrincipalMapper>,
540 revocation: Option<Arc<dyn RevocationChecker>>,
542 require_client_auth: bool,
544 allowed_principals: Vec<String>,
546}
547
548impl std::fmt::Debug for MtlsClientVerifier {
549 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
550 f.debug_struct("MtlsClientVerifier")
551 .field("roots", &"<RootCertStore>")
552 .field("mapper", &"<PrincipalMapper>")
553 .field(
554 "revocation",
555 &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
556 )
557 .field("require_client_auth", &self.require_client_auth)
558 .field("allowed_principals", &self.allowed_principals)
559 .finish()
560 }
561}
562
563impl MtlsClientVerifier {
564 pub fn new(roots: RootCertStore) -> Self {
566 Self {
567 roots: Arc::new(roots),
568 mapper: Arc::new(DefaultPrincipalMapper),
569 revocation: None,
570 require_client_auth: true,
571 allowed_principals: Vec::new(),
572 }
573 }
574
575 pub fn with_mapper(mut self, mapper: Arc<dyn PrincipalMapper>) -> Self {
577 self.mapper = mapper;
578 self
579 }
580
581 pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
583 self.revocation = Some(checker);
584 self
585 }
586
587 pub fn optional_auth(mut self) -> Self {
589 self.require_client_auth = false;
590 self
591 }
592
593 pub fn allow_principal(mut self, pattern: impl Into<String>) -> Self {
595 self.allowed_principals.push(pattern.into());
596 self
597 }
598
599 fn verify_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
601 let loader = CertificateLoader::new();
603 let info = loader.get_certificate_info(cert)?;
604
605 if !info.is_valid() {
607 return Err(NetError::InvalidCertificate(
608 "Certificate has expired or is not yet valid".to_string(),
609 ));
610 }
611
612 if let Some(ref checker) = self.revocation {
614 match checker.check_revocation(cert)? {
615 RevocationStatus::Revoked => {
616 return Err(NetError::InvalidCertificate(
617 "Certificate has been revoked".to_string(),
618 ));
619 }
620 RevocationStatus::CheckFailed => {
621 warn!("Revocation check failed, allowing certificate");
622 }
623 _ => {}
624 }
625 }
626
627 let principal = self.mapper.map_certificate(cert)?;
629
630 if !self.allowed_principals.is_empty() {
632 let principal_name = self.mapper.get_principal_name(&principal);
633 let is_allowed = self.allowed_principals.iter().any(|pattern| {
634 if pattern.contains('*') {
635 let regex_pattern = pattern.replace('*', ".*");
637 regex_pattern == principal_name
638 || principal_name.starts_with(&pattern.replace('*', ""))
639 } else {
640 pattern == &principal_name
641 }
642 });
643
644 if !is_allowed {
645 return Err(NetError::InsufficientPermissions(format!(
646 "Principal '{}' is not in the allowed list",
647 principal_name
648 )));
649 }
650 }
651
652 Ok(principal)
653 }
654}
655
656impl ClientCertVerifier for MtlsClientVerifier {
657 fn root_hint_subjects(&self) -> &[DistinguishedName] {
658 &[]
659 }
660
661 fn verify_client_cert(
662 &self,
663 end_entity: &CertificateDer<'_>,
664 _intermediates: &[CertificateDer<'_>],
665 _now: UnixTime,
666 ) -> Result<ClientCertVerified, rustls::Error> {
667 match self.verify_certificate(end_entity) {
668 Ok(principal) => {
669 debug!(principal = %principal.name, "Client certificate verified");
670 Ok(ClientCertVerified::assertion())
671 }
672 Err(e) => {
673 error!(error = %e, "Client certificate verification failed");
674 Err(rustls::Error::InvalidCertificate(
675 rustls::CertificateError::BadEncoding,
676 ))
677 }
678 }
679 }
680
681 fn verify_tls12_signature(
682 &self,
683 _message: &[u8],
684 _cert: &CertificateDer<'_>,
685 _dss: &DigitallySignedStruct,
686 ) -> Result<HandshakeSignatureValid, rustls::Error> {
687 Ok(HandshakeSignatureValid::assertion())
688 }
689
690 fn verify_tls13_signature(
691 &self,
692 _message: &[u8],
693 _cert: &CertificateDer<'_>,
694 _dss: &DigitallySignedStruct,
695 ) -> Result<HandshakeSignatureValid, rustls::Error> {
696 Ok(HandshakeSignatureValid::assertion())
697 }
698
699 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
700 vec![
701 SignatureScheme::RSA_PKCS1_SHA256,
702 SignatureScheme::RSA_PKCS1_SHA384,
703 SignatureScheme::RSA_PKCS1_SHA512,
704 SignatureScheme::ECDSA_NISTP256_SHA256,
705 SignatureScheme::ECDSA_NISTP384_SHA384,
706 SignatureScheme::ECDSA_NISTP521_SHA512,
707 SignatureScheme::ED25519,
708 ]
709 }
710
711 fn client_auth_mandatory(&self) -> bool {
712 self.require_client_auth
713 }
714}
715
716pub struct MtlsServerVerifier {
718 roots: Arc<RootCertStore>,
720 revocation: Option<Arc<dyn RevocationChecker>>,
722 expected_names: Vec<String>,
724}
725
726impl std::fmt::Debug for MtlsServerVerifier {
727 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
728 f.debug_struct("MtlsServerVerifier")
729 .field("roots", &"<RootCertStore>")
730 .field(
731 "revocation",
732 &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
733 )
734 .field("expected_names", &self.expected_names)
735 .finish()
736 }
737}
738
739impl MtlsServerVerifier {
740 pub fn new(roots: RootCertStore) -> Self {
742 Self {
743 roots: Arc::new(roots),
744 revocation: None,
745 expected_names: Vec::new(),
746 }
747 }
748
749 pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
751 self.revocation = Some(checker);
752 self
753 }
754
755 pub fn expect_name(mut self, name: impl Into<String>) -> Self {
757 self.expected_names.push(name.into());
758 self
759 }
760
761 fn verify_certificate(
763 &self,
764 cert: &CertificateDer<'_>,
765 server_name: Option<&str>,
766 ) -> NetResult<()> {
767 let loader = CertificateLoader::new();
768 let info = loader.get_certificate_info(cert)?;
769
770 if !info.is_valid() {
772 return Err(NetError::InvalidCertificate(
773 "Server certificate has expired or is not yet valid".to_string(),
774 ));
775 }
776
777 if let Some(ref checker) = self.revocation {
779 match checker.check_revocation(cert)? {
780 RevocationStatus::Revoked => {
781 return Err(NetError::InvalidCertificate(
782 "Server certificate has been revoked".to_string(),
783 ));
784 }
785 RevocationStatus::CheckFailed => {
786 warn!("Revocation check failed for server certificate");
787 }
788 _ => {}
789 }
790 }
791
792 if let Some(name) = server_name {
794 let name_matches = info.common_name.as_deref() == Some(name)
795 || info.subject_alt_names.iter().any(|san| san == name);
796
797 if !name_matches && !self.expected_names.is_empty() {
798 let expected_matches = self.expected_names.iter().any(|expected| {
799 info.common_name.as_deref() == Some(expected)
800 || info.subject_alt_names.iter().any(|san| san == expected)
801 });
802
803 if !expected_matches {
804 return Err(NetError::InvalidCertificate(format!(
805 "Server name '{}' does not match certificate",
806 name
807 )));
808 }
809 }
810 }
811
812 Ok(())
813 }
814}
815
816impl ServerCertVerifier for MtlsServerVerifier {
817 fn verify_server_cert(
818 &self,
819 end_entity: &CertificateDer<'_>,
820 _intermediates: &[CertificateDer<'_>],
821 server_name: &ServerName<'_>,
822 _ocsp_response: &[u8],
823 _now: UnixTime,
824 ) -> Result<ServerCertVerified, rustls::Error> {
825 let name_str = match server_name {
826 ServerName::DnsName(name) => Some(name.as_ref().to_string()),
827 ServerName::IpAddress(ip) => Some(format!("{:?}", ip)),
828 _ => None,
829 };
830
831 match self.verify_certificate(end_entity, name_str.as_deref()) {
832 Ok(()) => {
833 debug!("Server certificate verified");
834 Ok(ServerCertVerified::assertion())
835 }
836 Err(e) => {
837 error!(error = %e, "Server certificate verification failed");
838 Err(rustls::Error::InvalidCertificate(
839 rustls::CertificateError::BadEncoding,
840 ))
841 }
842 }
843 }
844
845 fn verify_tls12_signature(
846 &self,
847 _message: &[u8],
848 _cert: &CertificateDer<'_>,
849 _dss: &DigitallySignedStruct,
850 ) -> Result<HandshakeSignatureValid, rustls::Error> {
851 Ok(HandshakeSignatureValid::assertion())
852 }
853
854 fn verify_tls13_signature(
855 &self,
856 _message: &[u8],
857 _cert: &CertificateDer<'_>,
858 _dss: &DigitallySignedStruct,
859 ) -> Result<HandshakeSignatureValid, rustls::Error> {
860 Ok(HandshakeSignatureValid::assertion())
861 }
862
863 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
864 vec![
865 SignatureScheme::RSA_PKCS1_SHA256,
866 SignatureScheme::RSA_PKCS1_SHA384,
867 SignatureScheme::RSA_PKCS1_SHA512,
868 SignatureScheme::ECDSA_NISTP256_SHA256,
869 SignatureScheme::ECDSA_NISTP384_SHA384,
870 SignatureScheme::ECDSA_NISTP521_SHA512,
871 SignatureScheme::ED25519,
872 ]
873 }
874}
875
876pub struct MtlsConfigBuilder {
878 cert_chain: Vec<CertificateDer<'static>>,
880 private_key: Option<PrivateKeyDer<'static>>,
882 client_roots: RootCertStore,
884 server_roots: RootCertStore,
886 require_client_auth: bool,
888 mapper: Arc<dyn PrincipalMapper>,
890 revocation: Option<Arc<dyn RevocationChecker>>,
892 hot_reload: Option<Arc<HotReloadableCertificates>>,
894}
895
896impl std::fmt::Debug for MtlsConfigBuilder {
897 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
898 f.debug_struct("MtlsConfigBuilder")
899 .field("cert_chain", &format!("<{} certs>", self.cert_chain.len()))
900 .field("private_key", &self.private_key.as_ref().map(|_| "<key>"))
901 .field("client_roots", &"<RootCertStore>")
902 .field("server_roots", &"<RootCertStore>")
903 .field("require_client_auth", &self.require_client_auth)
904 .field("mapper", &"<PrincipalMapper>")
905 .field(
906 "revocation",
907 &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
908 )
909 .field(
910 "hot_reload",
911 &self.hot_reload.as_ref().map(|_| "<HotReloadable>"),
912 )
913 .finish()
914 }
915}
916
917impl Default for MtlsConfigBuilder {
918 fn default() -> Self {
919 Self::new()
920 }
921}
922
923impl MtlsConfigBuilder {
924 pub fn new() -> Self {
926 Self {
927 cert_chain: Vec::new(),
928 private_key: None,
929 client_roots: RootCertStore::empty(),
930 server_roots: RootCertStore::empty(),
931 require_client_auth: true,
932 mapper: Arc::new(DefaultPrincipalMapper),
933 revocation: None,
934 hot_reload: None,
935 }
936 }
937
938 pub fn with_identity(
940 mut self,
941 cert_chain: Vec<CertificateDer<'static>>,
942 private_key: PrivateKeyDer<'static>,
943 ) -> Self {
944 self.cert_chain = cert_chain;
945 self.private_key = Some(private_key);
946 self
947 }
948
949 pub fn with_identity_files<P: AsRef<Path>>(
951 mut self,
952 cert_path: P,
953 key_path: P,
954 ) -> NetResult<Self> {
955 let loader = CertificateLoader::new();
956 let key_loader = crate::tls::PrivateKeyLoader::new();
957
958 self.cert_chain = loader.load_pem_file(cert_path)?;
959 self.private_key = Some(key_loader.load_pem_file(key_path)?);
960
961 Ok(self)
962 }
963
964 pub fn with_client_ca(mut self, cert: CertificateDer<'static>) -> NetResult<Self> {
966 self.client_roots
967 .add(cert)
968 .map_err(|e| NetError::InvalidCertificate(format!("Failed to add client CA: {e}")))?;
969 Ok(self)
970 }
971
972 pub fn with_client_ca_store(mut self, store: &CertificateStore) -> Self {
974 let roots = store.get_root_store();
975 self.client_roots.extend(roots.roots.iter().cloned());
976 self
977 }
978
979 pub fn with_server_ca(mut self, cert: CertificateDer<'static>) -> NetResult<Self> {
981 self.server_roots
982 .add(cert)
983 .map_err(|e| NetError::InvalidCertificate(format!("Failed to add server CA: {e}")))?;
984 Ok(self)
985 }
986
987 pub fn with_server_ca_store(mut self, store: &CertificateStore) -> Self {
989 let roots = store.get_root_store();
990 self.server_roots.extend(roots.roots.iter().cloned());
991 self
992 }
993
994 pub fn with_system_roots(mut self) -> Self {
996 self.server_roots
997 .extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
998 self
999 }
1000
1001 pub fn require_client_auth(mut self, required: bool) -> Self {
1003 self.require_client_auth = required;
1004 self
1005 }
1006
1007 pub fn with_mapper(mut self, mapper: Arc<dyn PrincipalMapper>) -> Self {
1009 self.mapper = mapper;
1010 self
1011 }
1012
1013 pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
1015 self.revocation = Some(checker);
1016 self
1017 }
1018
1019 pub fn with_hot_reload(mut self, hot_reload: Arc<HotReloadableCertificates>) -> Self {
1021 self.hot_reload = Some(hot_reload);
1022 self
1023 }
1024
1025 pub fn build_server_config(self) -> NetResult<ServerConfig> {
1027 let private_key = self
1028 .private_key
1029 .ok_or_else(|| NetError::InvalidCertificate("Private key is required".to_string()))?;
1030
1031 if self.cert_chain.is_empty() {
1032 return Err(NetError::InvalidCertificate(
1033 "Certificate chain is required".to_string(),
1034 ));
1035 }
1036
1037 let client_verifier =
1039 Arc::new(MtlsClientVerifier::new(self.client_roots).with_mapper(self.mapper));
1040
1041 let config = if self.require_client_auth {
1042 ServerConfig::builder()
1043 .with_client_cert_verifier(client_verifier)
1044 .with_single_cert(self.cert_chain, private_key)
1045 .map_err(|e| {
1046 NetError::InvalidCertificate(format!("Failed to build server config: {e}"))
1047 })?
1048 } else {
1049 ServerConfig::builder()
1050 .with_no_client_auth()
1051 .with_single_cert(self.cert_chain, private_key)
1052 .map_err(|e| {
1053 NetError::InvalidCertificate(format!("Failed to build server config: {e}"))
1054 })?
1055 };
1056
1057 Ok(config)
1058 }
1059
1060 pub fn build_client_config(self) -> NetResult<ClientConfig> {
1062 let private_key = self.private_key.ok_or_else(|| {
1063 NetError::InvalidCertificate("Private key is required for client mTLS".to_string())
1064 })?;
1065
1066 if self.cert_chain.is_empty() {
1067 return Err(NetError::InvalidCertificate(
1068 "Certificate chain is required for client mTLS".to_string(),
1069 ));
1070 }
1071
1072 let server_verifier = Arc::new(MtlsServerVerifier::new(self.server_roots));
1074
1075 let config = ClientConfig::builder()
1076 .dangerous()
1077 .with_custom_certificate_verifier(server_verifier)
1078 .with_client_auth_cert(self.cert_chain, private_key)
1079 .map_err(|e| {
1080 NetError::InvalidCertificate(format!("Failed to build client config: {e}"))
1081 })?;
1082
1083 Ok(config)
1084 }
1085
1086 pub fn build_acceptor(self) -> NetResult<TlsAcceptor> {
1088 let config = self.build_server_config()?;
1089 Ok(TlsAcceptor::from(Arc::new(config)))
1090 }
1091
1092 pub fn build_connector(self) -> NetResult<TlsConnector> {
1094 let config = self.build_client_config()?;
1095 Ok(TlsConnector::from(Arc::new(config)))
1096 }
1097}
1098
1099pub struct MtlsServer {
1101 acceptor: TlsAcceptor,
1103 hot_reload: Option<Arc<HotReloadableCertificates>>,
1105}
1106
1107impl std::fmt::Debug for MtlsServer {
1108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1109 f.debug_struct("MtlsServer")
1110 .field("has_hot_reload", &self.hot_reload.is_some())
1111 .finish()
1112 }
1113}
1114
1115impl MtlsServer {
1116 pub fn builder() -> MtlsConfigBuilder {
1118 MtlsConfigBuilder::new()
1119 }
1120
1121 pub fn from_config(config: ServerConfig) -> Self {
1123 Self {
1124 acceptor: TlsAcceptor::from(Arc::new(config)),
1125 hot_reload: None,
1126 }
1127 }
1128
1129 pub fn acceptor(&self) -> &TlsAcceptor {
1131 &self.acceptor
1132 }
1133
1134 pub fn with_hot_reload(mut self, hot_reload: Arc<HotReloadableCertificates>) -> Self {
1136 self.hot_reload = Some(hot_reload);
1137 self
1138 }
1139}
1140
1141pub struct MtlsClient {
1143 connector: TlsConnector,
1145}
1146
1147impl std::fmt::Debug for MtlsClient {
1148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1149 f.debug_struct("MtlsClient").finish()
1150 }
1151}
1152
1153impl MtlsClient {
1154 pub fn builder() -> MtlsConfigBuilder {
1156 MtlsConfigBuilder::new()
1157 }
1158
1159 pub fn from_config(config: ClientConfig) -> Self {
1161 Self {
1162 connector: TlsConnector::from(Arc::new(config)),
1163 }
1164 }
1165
1166 pub fn connector(&self) -> &TlsConnector {
1168 &self.connector
1169 }
1170}
1171
1172#[derive(Debug, Clone)]
1174pub struct HandshakeResult {
1175 pub peer_principal: Option<Principal>,
1177 pub tls_version: String,
1179 pub cipher_suite: String,
1181 pub duration: Duration,
1183}
1184
1185impl HandshakeResult {
1186 pub fn is_authenticated(&self) -> bool {
1188 self.peer_principal.is_some()
1189 }
1190
1191 pub fn peer_name(&self) -> Option<&str> {
1193 self.peer_principal.as_ref().map(|p| p.name.as_str())
1194 }
1195}
1196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::SelfSignedGenerator;

    #[test]
    fn test_principal_from_certificate() {
        let (cert, _) = SelfSignedGenerator::new("test-user")
            .with_organization("Test Org")
            .generate()
            .expect("Should generate certificate");

        let principal = Principal::from_certificate(&cert).expect("Should create principal");

        assert_eq!(principal.name, "test-user");
        assert_eq!(principal.organization.as_deref(), Some("Test Org"));
        assert!(!principal.fingerprint.is_empty());
    }

    #[test]
    fn test_default_principal_mapper() {
        let (cert, _) = SelfSignedGenerator::new("test-user")
            .generate()
            .expect("Should generate certificate");

        let mapper = DefaultPrincipalMapper;
        let principal = mapper
            .map_certificate(&cert)
            .expect("Should map certificate");

        assert_eq!(mapper.get_principal_name(&principal), "test-user");
    }

    #[test]
    fn test_organization_principal_mapper() {
        let (cert, _) = SelfSignedGenerator::new("test-user")
            .with_organization("Test Org")
            .generate()
            .expect("Should generate certificate");

        let mapper = OrganizationPrincipalMapper;
        let principal = mapper
            .map_certificate(&cert)
            .expect("Should map certificate");

        assert_eq!(mapper.get_principal_name(&principal), "Test Org/test-user");
    }

    #[test]
    fn test_crl_revocation_checker() {
        let checker = CrlRevocationChecker::new();
        checker.add_revoked("abc123");

        assert!(checker.is_revoked("abc123"));
        assert!(!checker.is_revoked("def456"));
        assert_eq!(checker.revoked_count(), 1);
    }

    #[test]
    fn test_mtls_config_builder() {
        // Several tests may race to install the provider; ignore "already installed".
        let _ = rustls::crypto::ring::default_provider().install_default();

        let (ca_cert, _ca_key) = SelfSignedGenerator::new("Test CA")
            .as_ca()
            .with_validity_days(365)
            .generate()
            .expect("Should generate CA");

        let (server_cert, server_key) = SelfSignedGenerator::new("localhost")
            .with_san("127.0.0.1")
            .generate()
            .expect("Should generate server cert");

        let builder = MtlsConfigBuilder::new()
            .with_identity(vec![server_cert.clone()], server_key.clone_key())
            .with_client_ca(ca_cert.clone())
            .expect("Should add CA")
            .require_client_auth(true);

        assert!(builder.build_server_config().is_ok());
    }

    #[test]
    fn test_mtls_client_verifier() {
        let (ca_cert, _) = SelfSignedGenerator::new("Test CA")
            .as_ca()
            .generate()
            .expect("Should generate CA");

        let (client_cert, _) = SelfSignedGenerator::new("test-client")
            .with_organization("Test Org")
            .generate()
            .expect("Should generate client cert");

        let mut roots = RootCertStore::empty();
        roots.add(ca_cert).expect("Should add CA");

        // Construction must succeed; the verifier itself is not exercised here.
        let _verifier = MtlsClientVerifier::new(roots);

        let info = CertificateLoader::new()
            .get_certificate_info(&client_cert)
            .expect("Should get info");

        assert_eq!(info.common_name.as_deref(), Some("test-client"));
    }

    #[test]
    fn test_ocsp_revocation_checker_cache() {
        let checker = OcspRevocationChecker::new().with_cache_ttl(Duration::from_secs(3600));

        let (cert, _) = SelfSignedGenerator::new("test")
            .generate()
            .expect("Should generate cert");

        let status = checker
            .check_revocation(&cert)
            .expect("Should check revocation");
        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[test]
    fn test_combined_revocation_checker() {
        let crl = Arc::new(CrlRevocationChecker::new());
        let ocsp = Arc::new(OcspRevocationChecker::new());
        let combined = CombinedRevocationChecker::new(crl.clone(), ocsp);

        let (cert, _) = SelfSignedGenerator::new("test")
            .generate()
            .expect("Should generate cert");

        let status = combined
            .check_revocation(&cert)
            .expect("Should check revocation");
        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[test]
    fn test_handshake_result() {
        let principal = Principal {
            name: "test-user".to_string(),
            organization: Some("Test Org".to_string()),
            organizational_unit: None,
            email: None,
            serial: "123abc".to_string(),
            fingerprint: "abc123".to_string(),
            attributes: HashMap::new(),
        };

        let result = HandshakeResult {
            peer_principal: Some(principal),
            tls_version: "TLS 1.3".to_string(),
            cipher_suite: "TLS_AES_256_GCM_SHA384".to_string(),
            duration: Duration::from_millis(50),
        };

        assert!(result.is_authenticated());
        assert_eq!(result.peer_name(), Some("test-user"));
    }
}