1use std::collections::HashMap;
34use std::fs;
35use std::io::BufReader;
36use std::path::Path;
37use std::sync::Arc;
38use std::time::{Duration, SystemTime};
39
40use parking_lot::RwLock;
41use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
42use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
43use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
44use rustls::{
45 ClientConfig, DigitallySignedStruct, DistinguishedName, RootCertStore, ServerConfig,
46 SignatureScheme,
47};
48use tokio_rustls::{TlsAcceptor, TlsConnector};
49use tracing::{debug, error, info, warn};
50use x509_parser::prelude::*;
51
52use crate::error::{NetError, NetResult};
53use crate::tls::{CertificateInfo, CertificateLoader, CertificateStore, HotReloadableCertificates};
54
/// Identity extracted from a peer's X.509 certificate.
///
/// Built by `Principal::from_certificate`; `attributes` starts empty and may be extended by
/// callers via `Principal::with_attribute`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Principal {
    /// First subject Common Name, or `"unknown"` when it is absent or not a decodable string.
    pub name: String,
    /// First subject Organization (O), if present.
    pub organization: Option<String>,
    /// First subject Organizational Unit (OU), if present.
    pub organizational_unit: Option<String>,
    /// First RFC 822 (email) entry of the Subject Alternative Name extension, if any.
    pub email: Option<String>,
    /// Certificate serial number rendered as lowercase hexadecimal.
    pub serial: String,
    /// Hex identifier derived from the certificate's DER bytes.
    pub fingerprint: String,
    /// Free-form key/value metadata attached by callers.
    pub attributes: std::collections::HashMap<String, String>,
}
73
74impl Principal {
75 pub fn from_certificate(cert: &CertificateDer<'_>) -> NetResult<Self> {
77 let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
78 NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
79 })?;
80
81 let name = parsed
82 .subject()
83 .iter_common_name()
84 .next()
85 .and_then(|cn| cn.as_str().ok())
86 .map(String::from)
87 .unwrap_or_else(|| "unknown".to_string());
88
89 let organization = parsed
90 .subject()
91 .iter_organization()
92 .next()
93 .and_then(|o| o.as_str().ok())
94 .map(String::from);
95
96 let organizational_unit = parsed
97 .subject()
98 .iter_organizational_unit()
99 .next()
100 .and_then(|ou| ou.as_str().ok())
101 .map(String::from);
102
103 let mut email = None;
104 if let Ok(Some(san)) = parsed.subject_alternative_name() {
105 for name in san.value.general_names.iter() {
106 if let GeneralName::RFC822Name(e) = name {
107 email = Some(e.to_string());
108 break;
109 }
110 }
111 }
112
113 let serial = format!("{:x}", parsed.serial);
114 use std::fmt::Write;
116 let fingerprint = cert
117 .as_ref()
118 .iter()
119 .take(32)
120 .fold(String::new(), |mut s, b| {
121 let _ = write!(&mut s, "{b:02x}");
122 s
123 });
124
125 Ok(Self {
126 name,
127 organization,
128 organizational_unit,
129 email,
130 serial,
131 fingerprint,
132 attributes: HashMap::new(),
133 })
134 }
135
136 pub fn with_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
138 self.attributes.insert(key.into(), value.into());
139 self
140 }
141}
142
143pub trait PrincipalMapper: Send + Sync {
145 fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal>;
147
148 fn get_principal_name(&self, principal: &Principal) -> String;
150}
151
152#[derive(Debug, Clone, Default)]
154pub struct DefaultPrincipalMapper;
155
156impl PrincipalMapper for DefaultPrincipalMapper {
157 fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
158 Principal::from_certificate(cert)
159 }
160
161 fn get_principal_name(&self, principal: &Principal) -> String {
162 principal.name.clone()
163 }
164}
165
166#[derive(Debug, Clone, Default)]
168pub struct OrganizationPrincipalMapper;
169
170impl PrincipalMapper for OrganizationPrincipalMapper {
171 fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
172 Principal::from_certificate(cert)
173 }
174
175 fn get_principal_name(&self, principal: &Principal) -> String {
176 match &principal.organization {
177 Some(org) => format!("{}/{}", org, principal.name),
178 None => principal.name.clone(),
179 }
180 }
181}
182
/// Outcome of a certificate revocation lookup.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RevocationStatus {
    /// The checker has revocation data and found no record for the certificate.
    Good,
    /// The certificate appears in the checker's revocation data.
    Revoked,
    /// The checker has no data to decide (e.g. no CRL has been loaded yet).
    Unknown,
    /// The lookup could not be completed; the verifiers in this module log a warning and
    /// accept the certificate.
    CheckFailed,
}
195
196pub trait RevocationChecker: Send + Sync {
198 fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus>;
200
201 fn check_revocation_async(
203 &self,
204 cert: &CertificateDer<'_>,
205 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>;
206}
207
208#[derive(Debug)]
210pub struct CrlRevocationChecker {
211 revoked_serials: Arc<RwLock<HashMap<String, SystemTime>>>,
213 last_update: Arc<RwLock<Option<SystemTime>>>,
215 crl_url: Option<String>,
217}
218
219impl Default for CrlRevocationChecker {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
225impl CrlRevocationChecker {
226 pub fn new() -> Self {
228 Self {
229 revoked_serials: Arc::new(RwLock::new(HashMap::new())),
230 last_update: Arc::new(RwLock::new(None)),
231 crl_url: None,
232 }
233 }
234
235 pub fn with_crl_url(mut self, url: impl Into<String>) -> Self {
237 self.crl_url = Some(url.into());
238 self
239 }
240
241 pub fn load_crl_der<P: AsRef<Path>>(&self, path: P) -> NetResult<usize> {
243 let data = fs::read(path.as_ref())
244 .map_err(|e| NetError::InvalidCertificate(format!("Failed to read CRL file: {e}")))?;
245
246 self.load_crl_bytes(&data)
247 }
248
249 pub fn load_crl_pem<P: AsRef<Path>>(&self, path: P) -> NetResult<usize> {
251 let file = fs::File::open(path.as_ref())
252 .map_err(|e| NetError::InvalidCertificate(format!("Failed to open CRL file: {e}")))?;
253 let mut reader = BufReader::new(file);
254
255 let crls: Vec<_> = rustls_pemfile::crls(&mut reader)
256 .filter_map(|r| r.ok())
257 .collect();
258
259 if crls.is_empty() {
260 return Err(NetError::InvalidCertificate(
261 "No CRLs found in PEM file".to_string(),
262 ));
263 }
264
265 let mut total = 0;
266 for crl in crls {
267 total += self.load_crl_bytes(crl.as_ref())?;
268 }
269
270 Ok(total)
271 }
272
273 pub fn load_crl_bytes(&self, crl_data: &[u8]) -> NetResult<usize> {
275 let (_, crl) = CertificateRevocationList::from_der(crl_data)
276 .map_err(|e| NetError::InvalidCertificate(format!("Failed to parse CRL: {e}")))?;
277
278 let mut revoked = self.revoked_serials.write();
279 let mut count = 0;
280
281 for entry in crl.iter_revoked_certificates() {
282 let serial = format!("{:x}", entry.user_certificate);
283 let revocation_time = SystemTime::UNIX_EPOCH; revoked.insert(serial, revocation_time);
285 count += 1;
286 }
287
288 {
289 let mut last = self.last_update.write();
290 *last = Some(SystemTime::now());
291 }
292
293 info!(count = count, "Loaded CRL entries");
294 Ok(count)
295 }
296
297 pub fn add_revoked(&self, serial: impl Into<String>) {
299 let mut revoked = self.revoked_serials.write();
300 revoked.insert(serial.into(), SystemTime::now());
301 }
302
303 pub fn is_revoked(&self, serial: &str) -> bool {
305 self.revoked_serials.read().contains_key(serial)
306 }
307
308 pub fn get_revocation_time(&self, serial: &str) -> Option<SystemTime> {
310 self.revoked_serials.read().get(serial).copied()
311 }
312
313 pub fn revoked_count(&self) -> usize {
315 self.revoked_serials.read().len()
316 }
317
318 pub fn clear(&self) {
320 self.revoked_serials.write().clear();
321 *self.last_update.write() = None;
322 }
323}
324
325impl RevocationChecker for CrlRevocationChecker {
326 fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus> {
327 let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
328 NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
329 })?;
330
331 let serial = format!("{:x}", parsed.serial);
332
333 if self.is_revoked(&serial) {
334 Ok(RevocationStatus::Revoked)
335 } else if self.last_update.read().is_some() {
336 Ok(RevocationStatus::Good)
337 } else {
338 Ok(RevocationStatus::Unknown)
339 }
340 }
341
342 fn check_revocation_async(
343 &self,
344 cert: &CertificateDer<'_>,
345 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>
346 {
347 let result = self.check_revocation(cert);
348 Box::pin(async move { result })
349 }
350}
351
352pub use crate::ocsp::OcspRevocationChecker;
356
357#[derive(Debug)]
359pub struct CombinedRevocationChecker {
360 crl: Arc<CrlRevocationChecker>,
362 ocsp: Arc<OcspRevocationChecker>,
364 prefer_ocsp: bool,
366}
367
368impl CombinedRevocationChecker {
369 pub fn new(crl: Arc<CrlRevocationChecker>, ocsp: Arc<OcspRevocationChecker>) -> Self {
371 Self {
372 crl,
373 ocsp,
374 prefer_ocsp: false,
375 }
376 }
377
378 pub fn prefer_ocsp(mut self) -> Self {
380 self.prefer_ocsp = true;
381 self
382 }
383}
384
385impl RevocationChecker for CombinedRevocationChecker {
386 fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus> {
387 if self.prefer_ocsp {
388 match self.ocsp.check_revocation(cert)? {
390 RevocationStatus::Unknown | RevocationStatus::CheckFailed => {
391 self.crl.check_revocation(cert)
393 }
394 status => Ok(status),
395 }
396 } else {
397 match self.crl.check_revocation(cert)? {
399 RevocationStatus::Unknown | RevocationStatus::CheckFailed => {
400 self.ocsp.check_revocation(cert)
402 }
403 status => Ok(status),
404 }
405 }
406 }
407
408 fn check_revocation_async(
409 &self,
410 cert: &CertificateDer<'_>,
411 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>
412 {
413 let result = self.check_revocation(cert);
414 Box::pin(async move { result })
415 }
416}
417
418pub struct MtlsClientVerifier {
420 roots: Arc<RootCertStore>,
422 mapper: Arc<dyn PrincipalMapper>,
424 revocation: Option<Arc<dyn RevocationChecker>>,
426 require_client_auth: bool,
428 allowed_principals: Vec<String>,
430}
431
432impl std::fmt::Debug for MtlsClientVerifier {
433 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
434 f.debug_struct("MtlsClientVerifier")
435 .field("roots", &"<RootCertStore>")
436 .field("mapper", &"<PrincipalMapper>")
437 .field(
438 "revocation",
439 &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
440 )
441 .field("require_client_auth", &self.require_client_auth)
442 .field("allowed_principals", &self.allowed_principals)
443 .finish()
444 }
445}
446
447impl MtlsClientVerifier {
448 pub fn new(roots: RootCertStore) -> Self {
450 Self {
451 roots: Arc::new(roots),
452 mapper: Arc::new(DefaultPrincipalMapper),
453 revocation: None,
454 require_client_auth: true,
455 allowed_principals: Vec::new(),
456 }
457 }
458
459 pub fn with_mapper(mut self, mapper: Arc<dyn PrincipalMapper>) -> Self {
461 self.mapper = mapper;
462 self
463 }
464
465 pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
467 self.revocation = Some(checker);
468 self
469 }
470
471 pub fn optional_auth(mut self) -> Self {
473 self.require_client_auth = false;
474 self
475 }
476
477 pub fn allow_principal(mut self, pattern: impl Into<String>) -> Self {
479 self.allowed_principals.push(pattern.into());
480 self
481 }
482
483 fn verify_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
485 let loader = CertificateLoader::new();
487 let info = loader.get_certificate_info(cert)?;
488
489 if !info.is_valid() {
491 return Err(NetError::InvalidCertificate(
492 "Certificate has expired or is not yet valid".to_string(),
493 ));
494 }
495
496 if let Some(ref checker) = self.revocation {
498 match checker.check_revocation(cert)? {
499 RevocationStatus::Revoked => {
500 return Err(NetError::InvalidCertificate(
501 "Certificate has been revoked".to_string(),
502 ));
503 }
504 RevocationStatus::CheckFailed => {
505 warn!("Revocation check failed, allowing certificate");
506 }
507 _ => {}
508 }
509 }
510
511 let principal = self.mapper.map_certificate(cert)?;
513
514 if !self.allowed_principals.is_empty() {
516 let principal_name = self.mapper.get_principal_name(&principal);
517 let is_allowed = self.allowed_principals.iter().any(|pattern| {
518 if pattern.contains('*') {
519 let regex_pattern = pattern.replace('*', ".*");
521 regex_pattern == principal_name
522 || principal_name.starts_with(&pattern.replace('*', ""))
523 } else {
524 pattern == &principal_name
525 }
526 });
527
528 if !is_allowed {
529 return Err(NetError::InsufficientPermissions(format!(
530 "Principal '{}' is not in the allowed list",
531 principal_name
532 )));
533 }
534 }
535
536 Ok(principal)
537 }
538}
539
540impl ClientCertVerifier for MtlsClientVerifier {
541 fn root_hint_subjects(&self) -> &[DistinguishedName] {
542 &[]
543 }
544
545 fn verify_client_cert(
546 &self,
547 end_entity: &CertificateDer<'_>,
548 _intermediates: &[CertificateDer<'_>],
549 _now: UnixTime,
550 ) -> Result<ClientCertVerified, rustls::Error> {
551 match self.verify_certificate(end_entity) {
552 Ok(principal) => {
553 debug!(principal = %principal.name, "Client certificate verified");
554 Ok(ClientCertVerified::assertion())
555 }
556 Err(e) => {
557 error!(error = %e, "Client certificate verification failed");
558 Err(rustls::Error::InvalidCertificate(
559 rustls::CertificateError::BadEncoding,
560 ))
561 }
562 }
563 }
564
565 fn verify_tls12_signature(
566 &self,
567 _message: &[u8],
568 _cert: &CertificateDer<'_>,
569 _dss: &DigitallySignedStruct,
570 ) -> Result<HandshakeSignatureValid, rustls::Error> {
571 Ok(HandshakeSignatureValid::assertion())
572 }
573
574 fn verify_tls13_signature(
575 &self,
576 _message: &[u8],
577 _cert: &CertificateDer<'_>,
578 _dss: &DigitallySignedStruct,
579 ) -> Result<HandshakeSignatureValid, rustls::Error> {
580 Ok(HandshakeSignatureValid::assertion())
581 }
582
583 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
584 vec![
585 SignatureScheme::RSA_PKCS1_SHA256,
586 SignatureScheme::RSA_PKCS1_SHA384,
587 SignatureScheme::RSA_PKCS1_SHA512,
588 SignatureScheme::ECDSA_NISTP256_SHA256,
589 SignatureScheme::ECDSA_NISTP384_SHA384,
590 SignatureScheme::ECDSA_NISTP521_SHA512,
591 SignatureScheme::ED25519,
592 ]
593 }
594
595 fn client_auth_mandatory(&self) -> bool {
596 self.require_client_auth
597 }
598}
599
600pub struct MtlsServerVerifier {
602 roots: Arc<RootCertStore>,
604 revocation: Option<Arc<dyn RevocationChecker>>,
606 expected_names: Vec<String>,
608}
609
610impl std::fmt::Debug for MtlsServerVerifier {
611 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
612 f.debug_struct("MtlsServerVerifier")
613 .field("roots", &"<RootCertStore>")
614 .field(
615 "revocation",
616 &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
617 )
618 .field("expected_names", &self.expected_names)
619 .finish()
620 }
621}
622
623impl MtlsServerVerifier {
624 pub fn new(roots: RootCertStore) -> Self {
626 Self {
627 roots: Arc::new(roots),
628 revocation: None,
629 expected_names: Vec::new(),
630 }
631 }
632
633 pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
635 self.revocation = Some(checker);
636 self
637 }
638
639 pub fn expect_name(mut self, name: impl Into<String>) -> Self {
641 self.expected_names.push(name.into());
642 self
643 }
644
645 fn verify_certificate(
647 &self,
648 cert: &CertificateDer<'_>,
649 server_name: Option<&str>,
650 ) -> NetResult<()> {
651 let loader = CertificateLoader::new();
652 let info = loader.get_certificate_info(cert)?;
653
654 if !info.is_valid() {
656 return Err(NetError::InvalidCertificate(
657 "Server certificate has expired or is not yet valid".to_string(),
658 ));
659 }
660
661 if let Some(ref checker) = self.revocation {
663 match checker.check_revocation(cert)? {
664 RevocationStatus::Revoked => {
665 return Err(NetError::InvalidCertificate(
666 "Server certificate has been revoked".to_string(),
667 ));
668 }
669 RevocationStatus::CheckFailed => {
670 warn!("Revocation check failed for server certificate");
671 }
672 _ => {}
673 }
674 }
675
676 if let Some(name) = server_name {
678 let name_matches = info.common_name.as_deref() == Some(name)
679 || info.subject_alt_names.iter().any(|san| san == name);
680
681 if !name_matches && !self.expected_names.is_empty() {
682 let expected_matches = self.expected_names.iter().any(|expected| {
683 info.common_name.as_deref() == Some(expected)
684 || info.subject_alt_names.iter().any(|san| san == expected)
685 });
686
687 if !expected_matches {
688 return Err(NetError::InvalidCertificate(format!(
689 "Server name '{}' does not match certificate",
690 name
691 )));
692 }
693 }
694 }
695
696 Ok(())
697 }
698}
699
700impl ServerCertVerifier for MtlsServerVerifier {
701 fn verify_server_cert(
702 &self,
703 end_entity: &CertificateDer<'_>,
704 _intermediates: &[CertificateDer<'_>],
705 server_name: &ServerName<'_>,
706 _ocsp_response: &[u8],
707 _now: UnixTime,
708 ) -> Result<ServerCertVerified, rustls::Error> {
709 let name_str = match server_name {
710 ServerName::DnsName(name) => Some(name.as_ref().to_string()),
711 ServerName::IpAddress(ip) => Some(format!("{:?}", ip)),
712 _ => None,
713 };
714
715 match self.verify_certificate(end_entity, name_str.as_deref()) {
716 Ok(()) => {
717 debug!("Server certificate verified");
718 Ok(ServerCertVerified::assertion())
719 }
720 Err(e) => {
721 error!(error = %e, "Server certificate verification failed");
722 Err(rustls::Error::InvalidCertificate(
723 rustls::CertificateError::BadEncoding,
724 ))
725 }
726 }
727 }
728
729 fn verify_tls12_signature(
730 &self,
731 _message: &[u8],
732 _cert: &CertificateDer<'_>,
733 _dss: &DigitallySignedStruct,
734 ) -> Result<HandshakeSignatureValid, rustls::Error> {
735 Ok(HandshakeSignatureValid::assertion())
736 }
737
738 fn verify_tls13_signature(
739 &self,
740 _message: &[u8],
741 _cert: &CertificateDer<'_>,
742 _dss: &DigitallySignedStruct,
743 ) -> Result<HandshakeSignatureValid, rustls::Error> {
744 Ok(HandshakeSignatureValid::assertion())
745 }
746
747 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
748 vec![
749 SignatureScheme::RSA_PKCS1_SHA256,
750 SignatureScheme::RSA_PKCS1_SHA384,
751 SignatureScheme::RSA_PKCS1_SHA512,
752 SignatureScheme::ECDSA_NISTP256_SHA256,
753 SignatureScheme::ECDSA_NISTP384_SHA384,
754 SignatureScheme::ECDSA_NISTP521_SHA512,
755 SignatureScheme::ED25519,
756 ]
757 }
758}
759
760pub struct MtlsConfigBuilder {
762 cert_chain: Vec<CertificateDer<'static>>,
764 private_key: Option<PrivateKeyDer<'static>>,
766 client_roots: RootCertStore,
768 server_roots: RootCertStore,
770 require_client_auth: bool,
772 mapper: Arc<dyn PrincipalMapper>,
774 revocation: Option<Arc<dyn RevocationChecker>>,
776 hot_reload: Option<Arc<HotReloadableCertificates>>,
778}
779
780impl std::fmt::Debug for MtlsConfigBuilder {
781 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
782 f.debug_struct("MtlsConfigBuilder")
783 .field("cert_chain", &format!("<{} certs>", self.cert_chain.len()))
784 .field("private_key", &self.private_key.as_ref().map(|_| "<key>"))
785 .field("client_roots", &"<RootCertStore>")
786 .field("server_roots", &"<RootCertStore>")
787 .field("require_client_auth", &self.require_client_auth)
788 .field("mapper", &"<PrincipalMapper>")
789 .field(
790 "revocation",
791 &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
792 )
793 .field(
794 "hot_reload",
795 &self.hot_reload.as_ref().map(|_| "<HotReloadable>"),
796 )
797 .finish()
798 }
799}
800
801impl Default for MtlsConfigBuilder {
802 fn default() -> Self {
803 Self::new()
804 }
805}
806
807impl MtlsConfigBuilder {
808 pub fn new() -> Self {
810 Self {
811 cert_chain: Vec::new(),
812 private_key: None,
813 client_roots: RootCertStore::empty(),
814 server_roots: RootCertStore::empty(),
815 require_client_auth: true,
816 mapper: Arc::new(DefaultPrincipalMapper),
817 revocation: None,
818 hot_reload: None,
819 }
820 }
821
822 pub fn with_identity(
824 mut self,
825 cert_chain: Vec<CertificateDer<'static>>,
826 private_key: PrivateKeyDer<'static>,
827 ) -> Self {
828 self.cert_chain = cert_chain;
829 self.private_key = Some(private_key);
830 self
831 }
832
833 pub fn with_identity_files<P: AsRef<Path>>(
835 mut self,
836 cert_path: P,
837 key_path: P,
838 ) -> NetResult<Self> {
839 let loader = CertificateLoader::new();
840 let key_loader = crate::tls::PrivateKeyLoader::new();
841
842 self.cert_chain = loader.load_pem_file(cert_path)?;
843 self.private_key = Some(key_loader.load_pem_file(key_path)?);
844
845 Ok(self)
846 }
847
848 pub fn with_client_ca(mut self, cert: CertificateDer<'static>) -> NetResult<Self> {
850 self.client_roots
851 .add(cert)
852 .map_err(|e| NetError::InvalidCertificate(format!("Failed to add client CA: {e}")))?;
853 Ok(self)
854 }
855
856 pub fn with_client_ca_store(mut self, store: &CertificateStore) -> Self {
858 let roots = store.get_root_store();
859 self.client_roots.extend(roots.roots.iter().cloned());
860 self
861 }
862
863 pub fn with_server_ca(mut self, cert: CertificateDer<'static>) -> NetResult<Self> {
865 self.server_roots
866 .add(cert)
867 .map_err(|e| NetError::InvalidCertificate(format!("Failed to add server CA: {e}")))?;
868 Ok(self)
869 }
870
871 pub fn with_server_ca_store(mut self, store: &CertificateStore) -> Self {
873 let roots = store.get_root_store();
874 self.server_roots.extend(roots.roots.iter().cloned());
875 self
876 }
877
878 pub fn with_system_roots(mut self) -> Self {
880 self.server_roots
881 .extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
882 self
883 }
884
885 pub fn require_client_auth(mut self, required: bool) -> Self {
887 self.require_client_auth = required;
888 self
889 }
890
891 pub fn with_mapper(mut self, mapper: Arc<dyn PrincipalMapper>) -> Self {
893 self.mapper = mapper;
894 self
895 }
896
897 pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
899 self.revocation = Some(checker);
900 self
901 }
902
903 pub fn with_hot_reload(mut self, hot_reload: Arc<HotReloadableCertificates>) -> Self {
905 self.hot_reload = Some(hot_reload);
906 self
907 }
908
909 pub fn build_server_config(self) -> NetResult<ServerConfig> {
911 let private_key = self
912 .private_key
913 .ok_or_else(|| NetError::InvalidCertificate("Private key is required".to_string()))?;
914
915 if self.cert_chain.is_empty() {
916 return Err(NetError::InvalidCertificate(
917 "Certificate chain is required".to_string(),
918 ));
919 }
920
921 let client_verifier =
923 Arc::new(MtlsClientVerifier::new(self.client_roots).with_mapper(self.mapper));
924
925 let config = if self.require_client_auth {
926 ServerConfig::builder()
927 .with_client_cert_verifier(client_verifier)
928 .with_single_cert(self.cert_chain, private_key)
929 .map_err(|e| {
930 NetError::InvalidCertificate(format!("Failed to build server config: {e}"))
931 })?
932 } else {
933 ServerConfig::builder()
934 .with_no_client_auth()
935 .with_single_cert(self.cert_chain, private_key)
936 .map_err(|e| {
937 NetError::InvalidCertificate(format!("Failed to build server config: {e}"))
938 })?
939 };
940
941 Ok(config)
942 }
943
944 pub fn build_client_config(self) -> NetResult<ClientConfig> {
946 let private_key = self.private_key.ok_or_else(|| {
947 NetError::InvalidCertificate("Private key is required for client mTLS".to_string())
948 })?;
949
950 if self.cert_chain.is_empty() {
951 return Err(NetError::InvalidCertificate(
952 "Certificate chain is required for client mTLS".to_string(),
953 ));
954 }
955
956 let server_verifier = Arc::new(MtlsServerVerifier::new(self.server_roots));
958
959 let config = ClientConfig::builder()
960 .dangerous()
961 .with_custom_certificate_verifier(server_verifier)
962 .with_client_auth_cert(self.cert_chain, private_key)
963 .map_err(|e| {
964 NetError::InvalidCertificate(format!("Failed to build client config: {e}"))
965 })?;
966
967 Ok(config)
968 }
969
970 pub fn build_acceptor(self) -> NetResult<TlsAcceptor> {
972 let config = self.build_server_config()?;
973 Ok(TlsAcceptor::from(Arc::new(config)))
974 }
975
976 pub fn build_connector(self) -> NetResult<TlsConnector> {
978 let config = self.build_client_config()?;
979 Ok(TlsConnector::from(Arc::new(config)))
980 }
981}
982
983pub struct MtlsServer {
985 acceptor: TlsAcceptor,
987 hot_reload: Option<Arc<HotReloadableCertificates>>,
989}
990
991impl std::fmt::Debug for MtlsServer {
992 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
993 f.debug_struct("MtlsServer")
994 .field("has_hot_reload", &self.hot_reload.is_some())
995 .finish()
996 }
997}
998
999impl MtlsServer {
1000 pub fn builder() -> MtlsConfigBuilder {
1002 MtlsConfigBuilder::new()
1003 }
1004
1005 pub fn from_config(config: ServerConfig) -> Self {
1007 Self {
1008 acceptor: TlsAcceptor::from(Arc::new(config)),
1009 hot_reload: None,
1010 }
1011 }
1012
1013 pub fn acceptor(&self) -> &TlsAcceptor {
1015 &self.acceptor
1016 }
1017
1018 pub fn with_hot_reload(mut self, hot_reload: Arc<HotReloadableCertificates>) -> Self {
1020 self.hot_reload = Some(hot_reload);
1021 self
1022 }
1023}
1024
1025pub struct MtlsClient {
1027 connector: TlsConnector,
1029}
1030
1031impl std::fmt::Debug for MtlsClient {
1032 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1033 f.debug_struct("MtlsClient").finish()
1034 }
1035}
1036
1037impl MtlsClient {
1038 pub fn builder() -> MtlsConfigBuilder {
1040 MtlsConfigBuilder::new()
1041 }
1042
1043 pub fn from_config(config: ClientConfig) -> Self {
1045 Self {
1046 connector: TlsConnector::from(Arc::new(config)),
1047 }
1048 }
1049
1050 pub fn connector(&self) -> &TlsConnector {
1052 &self.connector
1053 }
1054}
1055
1056#[derive(Debug, Clone)]
1058pub struct HandshakeResult {
1059 pub peer_principal: Option<Principal>,
1061 pub tls_version: String,
1063 pub cipher_suite: String,
1065 pub duration: Duration,
1067}
1068
1069impl HandshakeResult {
1070 pub fn is_authenticated(&self) -> bool {
1072 self.peer_principal.is_some()
1073 }
1074
1075 pub fn peer_name(&self) -> Option<&str> {
1077 self.peer_principal.as_ref().map(|p| p.name.as_str())
1078 }
1079}
1080
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::SelfSignedGenerator;

    #[test]
    fn test_principal_from_certificate() {
        let (cert, _) = SelfSignedGenerator::new("test-user")
            .with_organization("Test Org")
            .generate()
            .expect("Should generate certificate");

        let principal = Principal::from_certificate(&cert).expect("Should create principal");

        assert_eq!(principal.name, "test-user");
        assert_eq!(principal.organization.as_deref(), Some("Test Org"));
        assert!(!principal.fingerprint.is_empty());
    }

    #[test]
    fn test_default_principal_mapper() {
        let (cert, _) = SelfSignedGenerator::new("test-user")
            .generate()
            .expect("Should generate certificate");
        let mapper = DefaultPrincipalMapper;

        let mapped = mapper
            .map_certificate(&cert)
            .expect("Should map certificate");

        assert_eq!(mapper.get_principal_name(&mapped), "test-user");
    }

    #[test]
    fn test_organization_principal_mapper() {
        let (cert, _) = SelfSignedGenerator::new("test-user")
            .with_organization("Test Org")
            .generate()
            .expect("Should generate certificate");
        let mapper = OrganizationPrincipalMapper;

        let mapped = mapper
            .map_certificate(&cert)
            .expect("Should map certificate");

        assert_eq!(mapper.get_principal_name(&mapped), "Test Org/test-user");
    }

    #[test]
    fn test_crl_revocation_checker() {
        let crl = CrlRevocationChecker::default();
        crl.add_revoked("abc123");

        assert!(crl.is_revoked("abc123"));
        assert!(!crl.is_revoked("def456"));
        assert_eq!(crl.revoked_count(), 1);
    }

    #[test]
    fn test_mtls_config_builder() {
        // Installing may fail if another test already installed a provider; that is fine.
        let _ = rustls::crypto::ring::default_provider().install_default();

        let (ca_cert, _ca_key) = SelfSignedGenerator::new("Test CA")
            .as_ca()
            .with_validity_days(365)
            .generate()
            .expect("Should generate CA");

        let (server_cert, server_key) = SelfSignedGenerator::new("localhost")
            .with_san("127.0.0.1")
            .generate()
            .expect("Should generate server cert");

        let built = MtlsConfigBuilder::new()
            .with_identity(vec![server_cert.clone()], server_key.clone_key())
            .with_client_ca(ca_cert.clone())
            .expect("Should add CA")
            .require_client_auth(true)
            .build_server_config();

        assert!(built.is_ok());
    }

    #[test]
    fn test_mtls_client_verifier() {
        let (ca_cert, _) = SelfSignedGenerator::new("Test CA")
            .as_ca()
            .generate()
            .expect("Should generate CA");

        let (client_cert, _) = SelfSignedGenerator::new("test-client")
            .with_organization("Test Org")
            .generate()
            .expect("Should generate client cert");

        let mut trust = RootCertStore::empty();
        trust.add(ca_cert).expect("Should add CA");
        let _verifier = MtlsClientVerifier::new(trust);

        let loader = CertificateLoader::new();
        let info = loader
            .get_certificate_info(&client_cert)
            .expect("Should get info");

        assert_eq!(info.common_name.as_deref(), Some("test-client"));
    }

    #[test]
    fn test_ocsp_revocation_checker_cache() {
        let ocsp = OcspRevocationChecker::new().with_cache_ttl(Duration::from_secs(3600));
        let (cert, _) = SelfSignedGenerator::new("test")
            .generate()
            .expect("Should generate cert");

        let status = ocsp
            .check_revocation(&cert)
            .expect("Should check revocation");

        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[test]
    fn test_combined_revocation_checker() {
        let combined = CombinedRevocationChecker::new(
            Arc::new(CrlRevocationChecker::new()),
            Arc::new(OcspRevocationChecker::new()),
        );
        let (cert, _) = SelfSignedGenerator::new("test")
            .generate()
            .expect("Should generate cert");

        let status = combined
            .check_revocation(&cert)
            .expect("Should check revocation");

        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[test]
    fn test_handshake_result() {
        let outcome = HandshakeResult {
            peer_principal: Some(Principal {
                name: String::from("test-user"),
                organization: Some(String::from("Test Org")),
                organizational_unit: None,
                email: None,
                serial: String::from("123abc"),
                fingerprint: String::from("abc123"),
                attributes: HashMap::new(),
            }),
            tls_version: String::from("TLS 1.3"),
            cipher_suite: String::from("TLS_AES_256_GCM_SHA384"),
            duration: Duration::from_millis(50),
        };

        assert!(outcome.is_authenticated());
        assert_eq!(outcome.peer_name(), Some("test-user"));
    }
}