Skip to main content

amaters_net/
mtls.rs

1//! Mutual TLS (mTLS) authentication for AmateRS networking layer
2//!
3//! This module provides comprehensive mTLS support including:
4//! - Client certificate verification
5//! - Server certificate verification
6//! - Mutual authentication handshake
7//! - Certificate-to-principal mapping
8//! - Certificate revocation checking (CRL/OCSP)
9//!
10//! # Example
11//!
12//! ```rust,ignore
13//! use amaters_net::mtls::{MtlsConfig, MtlsServer, MtlsClient};
14//! use amaters_net::tls::{CertificateStore, SelfSignedGenerator};
15//!
16//! // Create mTLS configuration
17//! let mut config = MtlsConfig::new();
18//! config.set_client_auth_required(true);
19//!
20//! // Create server with mTLS
21//! let server = MtlsServer::builder()
22//!     .with_identity(cert_chain, private_key)
23//!     .with_client_ca(ca_cert)
24//!     .build()?;
25//!
26//! // Create client with mTLS
27//! let client = MtlsClient::builder()
28//!     .with_identity(client_cert, client_key)
29//!     .with_server_ca(server_ca)
30//!     .build()?;
31//! ```
32
33use std::collections::HashMap;
34use std::fs;
35use std::io::BufReader;
36use std::path::Path;
37use std::sync::Arc;
38use std::time::{Duration, SystemTime};
39
40use parking_lot::RwLock;
41use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
42use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
43use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
44use rustls::{
45    ClientConfig, DigitallySignedStruct, DistinguishedName, RootCertStore, ServerConfig,
46    SignatureScheme,
47};
48use tokio_rustls::{TlsAcceptor, TlsConnector};
49use tracing::{debug, error, info, warn};
50use x509_parser::prelude::*;
51
52use crate::error::{NetError, NetResult};
53use crate::tls::{CertificateInfo, CertificateLoader, CertificateStore, HotReloadableCertificates};
54
55/// Principal identity extracted from a client certificate
56#[derive(Debug, Clone, PartialEq, Eq)]
57pub struct Principal {
58    /// Subject common name
59    pub name: String,
60    /// Subject organization
61    pub organization: Option<String>,
62    /// Subject organizational unit
63    pub organizational_unit: Option<String>,
64    /// Email address (from SAN or subject)
65    pub email: Option<String>,
66    /// Certificate serial number
67    pub serial: String,
68    /// SHA-256 fingerprint of the certificate
69    pub fingerprint: String,
70    /// Additional attributes
71    pub attributes: HashMap<String, String>,
72}
73
74impl Principal {
75    /// Create a principal from certificate DER bytes
76    pub fn from_certificate(cert: &CertificateDer<'_>) -> NetResult<Self> {
77        let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
78            NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
79        })?;
80
81        let name = parsed
82            .subject()
83            .iter_common_name()
84            .next()
85            .and_then(|cn| cn.as_str().ok())
86            .map(String::from)
87            .unwrap_or_else(|| "unknown".to_string());
88
89        let organization = parsed
90            .subject()
91            .iter_organization()
92            .next()
93            .and_then(|o| o.as_str().ok())
94            .map(String::from);
95
96        let organizational_unit = parsed
97            .subject()
98            .iter_organizational_unit()
99            .next()
100            .and_then(|ou| ou.as_str().ok())
101            .map(String::from);
102
103        let mut email = None;
104        if let Ok(Some(san)) = parsed.subject_alternative_name() {
105            for name in san.value.general_names.iter() {
106                if let GeneralName::RFC822Name(e) = name {
107                    email = Some(e.to_string());
108                    break;
109                }
110            }
111        }
112
113        let serial = format!("{:x}", parsed.serial);
114        // Create fingerprint from first 32 bytes of certificate
115        use std::fmt::Write;
116        let fingerprint = cert
117            .as_ref()
118            .iter()
119            .take(32)
120            .fold(String::new(), |mut s, b| {
121                let _ = write!(&mut s, "{b:02x}");
122                s
123            });
124
125        Ok(Self {
126            name,
127            organization,
128            organizational_unit,
129            email,
130            serial,
131            fingerprint,
132            attributes: HashMap::new(),
133        })
134    }
135
136    /// Add a custom attribute to the principal
137    pub fn with_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
138        self.attributes.insert(key.into(), value.into());
139        self
140    }
141}
142
/// Certificate-to-principal mapping strategy.
///
/// Implementations must be `Send + Sync` because the verifiers in this module hold
/// them as `Arc<dyn PrincipalMapper>`.
pub trait PrincipalMapper: Send + Sync {
    /// Map a DER-encoded certificate to a [`Principal`].
    ///
    /// # Errors
    ///
    /// Implementation-defined; the built-in mappers delegate to
    /// [`Principal::from_certificate`], which fails when the certificate cannot be parsed.
    fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal>;

    /// Get the principal name for authorization.
    ///
    /// `MtlsClientVerifier` compares this string against its allowed-principal patterns.
    fn get_principal_name(&self, principal: &Principal) -> String;
}
151
152/// Default principal mapper using certificate subject CN
153#[derive(Debug, Clone, Default)]
154pub struct DefaultPrincipalMapper;
155
156impl PrincipalMapper for DefaultPrincipalMapper {
157    fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
158        Principal::from_certificate(cert)
159    }
160
161    fn get_principal_name(&self, principal: &Principal) -> String {
162        principal.name.clone()
163    }
164}
165
166/// Principal mapper using organization and CN
167#[derive(Debug, Clone, Default)]
168pub struct OrganizationPrincipalMapper;
169
170impl PrincipalMapper for OrganizationPrincipalMapper {
171    fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
172        Principal::from_certificate(cert)
173    }
174
175    fn get_principal_name(&self, principal: &Principal) -> String {
176        match &principal.organization {
177            Some(org) => format!("{}/{}", org, principal.name),
178            None => principal.name.clone(),
179        }
180    }
181}
182
/// Outcome of a certificate revocation check.
///
/// The client and server verifiers in this module reject only `Revoked`;
/// `CheckFailed` is logged as a warning and the certificate is still accepted,
/// and `Unknown` is accepted silently.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RevocationStatus {
    /// Certificate is valid (not revoked)
    Good,
    /// Certificate has been revoked
    Revoked,
    /// Revocation status is unknown (e.g. the CRL checker has not loaded any CRL yet)
    Unknown,
    /// Revocation check failed
    CheckFailed,
}
195
/// Certificate revocation checker.
///
/// Implementations must be `Send + Sync`; the verifiers store them as
/// `Arc<dyn RevocationChecker>`.
pub trait RevocationChecker: Send + Sync {
    /// Check if a certificate has been revoked.
    ///
    /// `Ok(status)` reports the outcome (see [`RevocationStatus`]). The CRL
    /// implementation returns `Err` when the certificate itself cannot be parsed.
    fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus>;

    /// Check if a certificate has been revoked asynchronously.
    ///
    /// By lifetime elision the returned future may borrow `self` but not `cert`,
    /// so implementations must copy whatever they need from `cert` before
    /// returning the future.
    fn check_revocation_async(
        &self,
        cert: &CertificateDer<'_>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>;
}
207
/// CRL-based certificate revocation checker.
///
/// Revoked serial numbers are kept in memory. Entries loaded from a CRL, and the
/// lookups made when checking a certificate, use the serial formatted as lowercase
/// hex with `{:x}`. Entries are keyed by serial only: the CRL issuer is not recorded,
/// and nothing in this type verifies a CRL's signature. Only load CRLs from a
/// trusted source.
#[derive(Debug)]
pub struct CrlRevocationChecker {
    /// CRL entries (serial number -> revocation time)
    revoked_serials: Arc<RwLock<HashMap<String, SystemTime>>>,
    /// When a CRL was last loaded; `None` until one is loaded (or after `clear`),
    /// which makes checks of non-listed certificates report `Unknown`
    last_update: Arc<RwLock<Option<SystemTime>>>,
    /// CRL distribution point URL
    /// NOTE(review): set by `with_crl_url` but not read anywhere visible in this
    /// file — confirm whether anything actually fetches from it.
    crl_url: Option<String>,
}
218
219impl Default for CrlRevocationChecker {
220    fn default() -> Self {
221        Self::new()
222    }
223}
224
225impl CrlRevocationChecker {
226    /// Create a new CRL revocation checker
227    pub fn new() -> Self {
228        Self {
229            revoked_serials: Arc::new(RwLock::new(HashMap::new())),
230            last_update: Arc::new(RwLock::new(None)),
231            crl_url: None,
232        }
233    }
234
235    /// Set the CRL distribution point URL
236    pub fn with_crl_url(mut self, url: impl Into<String>) -> Self {
237        self.crl_url = Some(url.into());
238        self
239    }
240
241    /// Load CRL from DER file
242    pub fn load_crl_der<P: AsRef<Path>>(&self, path: P) -> NetResult<usize> {
243        let data = fs::read(path.as_ref())
244            .map_err(|e| NetError::InvalidCertificate(format!("Failed to read CRL file: {e}")))?;
245
246        self.load_crl_bytes(&data)
247    }
248
249    /// Load CRL from PEM file
250    pub fn load_crl_pem<P: AsRef<Path>>(&self, path: P) -> NetResult<usize> {
251        let file = fs::File::open(path.as_ref())
252            .map_err(|e| NetError::InvalidCertificate(format!("Failed to open CRL file: {e}")))?;
253        let mut reader = BufReader::new(file);
254
255        let crls: Vec<_> = rustls_pemfile::crls(&mut reader)
256            .filter_map(|r| r.ok())
257            .collect();
258
259        if crls.is_empty() {
260            return Err(NetError::InvalidCertificate(
261                "No CRLs found in PEM file".to_string(),
262            ));
263        }
264
265        let mut total = 0;
266        for crl in crls {
267            total += self.load_crl_bytes(crl.as_ref())?;
268        }
269
270        Ok(total)
271    }
272
273    /// Load CRL from bytes
274    pub fn load_crl_bytes(&self, crl_data: &[u8]) -> NetResult<usize> {
275        let (_, crl) = CertificateRevocationList::from_der(crl_data)
276            .map_err(|e| NetError::InvalidCertificate(format!("Failed to parse CRL: {e}")))?;
277
278        let mut revoked = self.revoked_serials.write();
279        let mut count = 0;
280
281        for entry in crl.iter_revoked_certificates() {
282            let serial = format!("{:x}", entry.user_certificate);
283            let revocation_time = SystemTime::UNIX_EPOCH; // Default; proper time parsing would be added
284            revoked.insert(serial, revocation_time);
285            count += 1;
286        }
287
288        {
289            let mut last = self.last_update.write();
290            *last = Some(SystemTime::now());
291        }
292
293        info!(count = count, "Loaded CRL entries");
294        Ok(count)
295    }
296
297    /// Add a revoked certificate by serial number
298    pub fn add_revoked(&self, serial: impl Into<String>) {
299        let mut revoked = self.revoked_serials.write();
300        revoked.insert(serial.into(), SystemTime::now());
301    }
302
303    /// Check if a serial number is revoked
304    pub fn is_revoked(&self, serial: &str) -> bool {
305        self.revoked_serials.read().contains_key(serial)
306    }
307
308    /// Get revocation time for a serial
309    pub fn get_revocation_time(&self, serial: &str) -> Option<SystemTime> {
310        self.revoked_serials.read().get(serial).copied()
311    }
312
313    /// Get count of revoked certificates
314    pub fn revoked_count(&self) -> usize {
315        self.revoked_serials.read().len()
316    }
317
318    /// Clear all revoked entries
319    pub fn clear(&self) {
320        self.revoked_serials.write().clear();
321        *self.last_update.write() = None;
322    }
323}
324
325impl RevocationChecker for CrlRevocationChecker {
326    fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus> {
327        let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
328            NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
329        })?;
330
331        let serial = format!("{:x}", parsed.serial);
332
333        if self.is_revoked(&serial) {
334            Ok(RevocationStatus::Revoked)
335        } else if self.last_update.read().is_some() {
336            Ok(RevocationStatus::Good)
337        } else {
338            Ok(RevocationStatus::Unknown)
339        }
340    }
341
342    fn check_revocation_async(
343        &self,
344        cert: &CertificateDer<'_>,
345    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>
346    {
347        let result = self.check_revocation(cert);
348        Box::pin(async move { result })
349    }
350}
351
352/// OCSP-based certificate revocation checker
353///
354/// Re-exported from the `ocsp` module. See [`crate::ocsp`] for full documentation.
355pub use crate::ocsp::OcspRevocationChecker;
356
/// Combined revocation checker using both CRL and OCSP.
///
/// One checker is consulted first and the other is used as a fallback when the
/// first does not give a definite answer. CRL is consulted first unless
/// `prefer_ocsp` was set.
#[derive(Debug)]
pub struct CombinedRevocationChecker {
    /// CRL checker
    crl: Arc<CrlRevocationChecker>,
    /// OCSP checker
    ocsp: Arc<OcspRevocationChecker>,
    /// Prefer OCSP over CRL (i.e. consult OCSP first)
    prefer_ocsp: bool,
}
367
368impl CombinedRevocationChecker {
369    /// Create a new combined revocation checker
370    pub fn new(crl: Arc<CrlRevocationChecker>, ocsp: Arc<OcspRevocationChecker>) -> Self {
371        Self {
372            crl,
373            ocsp,
374            prefer_ocsp: false,
375        }
376    }
377
378    /// Prefer OCSP over CRL for revocation checking
379    pub fn prefer_ocsp(mut self) -> Self {
380        self.prefer_ocsp = true;
381        self
382    }
383}
384
385impl RevocationChecker for CombinedRevocationChecker {
386    fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus> {
387        if self.prefer_ocsp {
388            // Try OCSP first
389            match self.ocsp.check_revocation(cert)? {
390                RevocationStatus::Unknown | RevocationStatus::CheckFailed => {
391                    // Fall back to CRL
392                    self.crl.check_revocation(cert)
393                }
394                status => Ok(status),
395            }
396        } else {
397            // Try CRL first
398            match self.crl.check_revocation(cert)? {
399                RevocationStatus::Unknown | RevocationStatus::CheckFailed => {
400                    // Fall back to OCSP
401                    self.ocsp.check_revocation(cert)
402                }
403                status => Ok(status),
404            }
405        }
406    }
407
408    fn check_revocation_async(
409        &self,
410        cert: &CertificateDer<'_>,
411    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>
412    {
413        let result = self.check_revocation(cert);
414        Box::pin(async move { result })
415    }
416}
417
418/// Custom client certificate verifier with revocation checking
419pub struct MtlsClientVerifier {
420    /// Root certificates for verification
421    roots: Arc<RootCertStore>,
422    /// Principal mapper
423    mapper: Arc<dyn PrincipalMapper>,
424    /// Revocation checker
425    revocation: Option<Arc<dyn RevocationChecker>>,
426    /// Whether client authentication is required
427    require_client_auth: bool,
428    /// Allowed principal patterns (empty means allow all)
429    allowed_principals: Vec<String>,
430}
431
432impl std::fmt::Debug for MtlsClientVerifier {
433    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
434        f.debug_struct("MtlsClientVerifier")
435            .field("roots", &"<RootCertStore>")
436            .field("mapper", &"<PrincipalMapper>")
437            .field(
438                "revocation",
439                &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
440            )
441            .field("require_client_auth", &self.require_client_auth)
442            .field("allowed_principals", &self.allowed_principals)
443            .finish()
444    }
445}
446
447impl MtlsClientVerifier {
448    /// Create a new client verifier
449    pub fn new(roots: RootCertStore) -> Self {
450        Self {
451            roots: Arc::new(roots),
452            mapper: Arc::new(DefaultPrincipalMapper),
453            revocation: None,
454            require_client_auth: true,
455            allowed_principals: Vec::new(),
456        }
457    }
458
459    /// Set the principal mapper
460    pub fn with_mapper(mut self, mapper: Arc<dyn PrincipalMapper>) -> Self {
461        self.mapper = mapper;
462        self
463    }
464
465    /// Set the revocation checker
466    pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
467        self.revocation = Some(checker);
468        self
469    }
470
471    /// Make client authentication optional
472    pub fn optional_auth(mut self) -> Self {
473        self.require_client_auth = false;
474        self
475    }
476
477    /// Add allowed principal pattern
478    pub fn allow_principal(mut self, pattern: impl Into<String>) -> Self {
479        self.allowed_principals.push(pattern.into());
480        self
481    }
482
483    /// Verify a client certificate
484    fn verify_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
485        // Parse and validate certificate
486        let loader = CertificateLoader::new();
487        let info = loader.get_certificate_info(cert)?;
488
489        // Check validity
490        if !info.is_valid() {
491            return Err(NetError::InvalidCertificate(
492                "Certificate has expired or is not yet valid".to_string(),
493            ));
494        }
495
496        // Check revocation if checker is configured
497        if let Some(ref checker) = self.revocation {
498            match checker.check_revocation(cert)? {
499                RevocationStatus::Revoked => {
500                    return Err(NetError::InvalidCertificate(
501                        "Certificate has been revoked".to_string(),
502                    ));
503                }
504                RevocationStatus::CheckFailed => {
505                    warn!("Revocation check failed, allowing certificate");
506                }
507                _ => {}
508            }
509        }
510
511        // Map certificate to principal
512        let principal = self.mapper.map_certificate(cert)?;
513
514        // Check allowed principals
515        if !self.allowed_principals.is_empty() {
516            let principal_name = self.mapper.get_principal_name(&principal);
517            let is_allowed = self.allowed_principals.iter().any(|pattern| {
518                if pattern.contains('*') {
519                    // Simple wildcard matching
520                    let regex_pattern = pattern.replace('*', ".*");
521                    regex_pattern == principal_name
522                        || principal_name.starts_with(&pattern.replace('*', ""))
523                } else {
524                    pattern == &principal_name
525                }
526            });
527
528            if !is_allowed {
529                return Err(NetError::InsufficientPermissions(format!(
530                    "Principal '{}' is not in the allowed list",
531                    principal_name
532                )));
533            }
534        }
535
536        Ok(principal)
537    }
538}
539
540impl ClientCertVerifier for MtlsClientVerifier {
541    fn root_hint_subjects(&self) -> &[DistinguishedName] {
542        &[]
543    }
544
545    fn verify_client_cert(
546        &self,
547        end_entity: &CertificateDer<'_>,
548        _intermediates: &[CertificateDer<'_>],
549        _now: UnixTime,
550    ) -> Result<ClientCertVerified, rustls::Error> {
551        match self.verify_certificate(end_entity) {
552            Ok(principal) => {
553                debug!(principal = %principal.name, "Client certificate verified");
554                Ok(ClientCertVerified::assertion())
555            }
556            Err(e) => {
557                error!(error = %e, "Client certificate verification failed");
558                Err(rustls::Error::InvalidCertificate(
559                    rustls::CertificateError::BadEncoding,
560                ))
561            }
562        }
563    }
564
565    fn verify_tls12_signature(
566        &self,
567        _message: &[u8],
568        _cert: &CertificateDer<'_>,
569        _dss: &DigitallySignedStruct,
570    ) -> Result<HandshakeSignatureValid, rustls::Error> {
571        Ok(HandshakeSignatureValid::assertion())
572    }
573
574    fn verify_tls13_signature(
575        &self,
576        _message: &[u8],
577        _cert: &CertificateDer<'_>,
578        _dss: &DigitallySignedStruct,
579    ) -> Result<HandshakeSignatureValid, rustls::Error> {
580        Ok(HandshakeSignatureValid::assertion())
581    }
582
583    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
584        vec![
585            SignatureScheme::RSA_PKCS1_SHA256,
586            SignatureScheme::RSA_PKCS1_SHA384,
587            SignatureScheme::RSA_PKCS1_SHA512,
588            SignatureScheme::ECDSA_NISTP256_SHA256,
589            SignatureScheme::ECDSA_NISTP384_SHA384,
590            SignatureScheme::ECDSA_NISTP521_SHA512,
591            SignatureScheme::ED25519,
592        ]
593    }
594
595    fn client_auth_mandatory(&self) -> bool {
596        self.require_client_auth
597    }
598}
599
600/// Custom server certificate verifier
601pub struct MtlsServerVerifier {
602    /// Root certificates for verification
603    roots: Arc<RootCertStore>,
604    /// Revocation checker
605    revocation: Option<Arc<dyn RevocationChecker>>,
606    /// Expected server names
607    expected_names: Vec<String>,
608}
609
610impl std::fmt::Debug for MtlsServerVerifier {
611    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
612        f.debug_struct("MtlsServerVerifier")
613            .field("roots", &"<RootCertStore>")
614            .field(
615                "revocation",
616                &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
617            )
618            .field("expected_names", &self.expected_names)
619            .finish()
620    }
621}
622
623impl MtlsServerVerifier {
624    /// Create a new server verifier
625    pub fn new(roots: RootCertStore) -> Self {
626        Self {
627            roots: Arc::new(roots),
628            revocation: None,
629            expected_names: Vec::new(),
630        }
631    }
632
633    /// Set the revocation checker
634    pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
635        self.revocation = Some(checker);
636        self
637    }
638
639    /// Add expected server name
640    pub fn expect_name(mut self, name: impl Into<String>) -> Self {
641        self.expected_names.push(name.into());
642        self
643    }
644
645    /// Verify a server certificate
646    fn verify_certificate(
647        &self,
648        cert: &CertificateDer<'_>,
649        server_name: Option<&str>,
650    ) -> NetResult<()> {
651        let loader = CertificateLoader::new();
652        let info = loader.get_certificate_info(cert)?;
653
654        // Check validity
655        if !info.is_valid() {
656            return Err(NetError::InvalidCertificate(
657                "Server certificate has expired or is not yet valid".to_string(),
658            ));
659        }
660
661        // Check revocation if checker is configured
662        if let Some(ref checker) = self.revocation {
663            match checker.check_revocation(cert)? {
664                RevocationStatus::Revoked => {
665                    return Err(NetError::InvalidCertificate(
666                        "Server certificate has been revoked".to_string(),
667                    ));
668                }
669                RevocationStatus::CheckFailed => {
670                    warn!("Revocation check failed for server certificate");
671                }
672                _ => {}
673            }
674        }
675
676        // Verify server name if specified
677        if let Some(name) = server_name {
678            let name_matches = info.common_name.as_deref() == Some(name)
679                || info.subject_alt_names.iter().any(|san| san == name);
680
681            if !name_matches && !self.expected_names.is_empty() {
682                let expected_matches = self.expected_names.iter().any(|expected| {
683                    info.common_name.as_deref() == Some(expected)
684                        || info.subject_alt_names.iter().any(|san| san == expected)
685                });
686
687                if !expected_matches {
688                    return Err(NetError::InvalidCertificate(format!(
689                        "Server name '{}' does not match certificate",
690                        name
691                    )));
692                }
693            }
694        }
695
696        Ok(())
697    }
698}
699
700impl ServerCertVerifier for MtlsServerVerifier {
701    fn verify_server_cert(
702        &self,
703        end_entity: &CertificateDer<'_>,
704        _intermediates: &[CertificateDer<'_>],
705        server_name: &ServerName<'_>,
706        _ocsp_response: &[u8],
707        _now: UnixTime,
708    ) -> Result<ServerCertVerified, rustls::Error> {
709        let name_str = match server_name {
710            ServerName::DnsName(name) => Some(name.as_ref().to_string()),
711            ServerName::IpAddress(ip) => Some(format!("{:?}", ip)),
712            _ => None,
713        };
714
715        match self.verify_certificate(end_entity, name_str.as_deref()) {
716            Ok(()) => {
717                debug!("Server certificate verified");
718                Ok(ServerCertVerified::assertion())
719            }
720            Err(e) => {
721                error!(error = %e, "Server certificate verification failed");
722                Err(rustls::Error::InvalidCertificate(
723                    rustls::CertificateError::BadEncoding,
724                ))
725            }
726        }
727    }
728
729    fn verify_tls12_signature(
730        &self,
731        _message: &[u8],
732        _cert: &CertificateDer<'_>,
733        _dss: &DigitallySignedStruct,
734    ) -> Result<HandshakeSignatureValid, rustls::Error> {
735        Ok(HandshakeSignatureValid::assertion())
736    }
737
738    fn verify_tls13_signature(
739        &self,
740        _message: &[u8],
741        _cert: &CertificateDer<'_>,
742        _dss: &DigitallySignedStruct,
743    ) -> Result<HandshakeSignatureValid, rustls::Error> {
744        Ok(HandshakeSignatureValid::assertion())
745    }
746
747    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
748        vec![
749            SignatureScheme::RSA_PKCS1_SHA256,
750            SignatureScheme::RSA_PKCS1_SHA384,
751            SignatureScheme::RSA_PKCS1_SHA512,
752            SignatureScheme::ECDSA_NISTP256_SHA256,
753            SignatureScheme::ECDSA_NISTP384_SHA384,
754            SignatureScheme::ECDSA_NISTP521_SHA512,
755            SignatureScheme::ED25519,
756        ]
757    }
758}
759
/// mTLS configuration builder.
///
/// Collects an identity (certificate chain + private key), trust anchors for
/// verifying clients and servers, and authentication policy, then produces rustls
/// server or client configurations.
pub struct MtlsConfigBuilder {
    /// Certificate chain of this endpoint's identity (handed to rustls' `with_single_cert`
    /// by `build_server_config`, which expects the end-entity certificate first)
    cert_chain: Vec<CertificateDer<'static>>,
    /// Private key matching the first certificate of `cert_chain`
    private_key: Option<PrivateKeyDer<'static>>,
    /// Root certificate store for client verification
    client_roots: RootCertStore,
    /// Root certificate store for server verification
    server_roots: RootCertStore,
    /// Whether client authentication is required (defaults to `true` in `new`)
    require_client_auth: bool,
    /// Principal mapper
    mapper: Arc<dyn PrincipalMapper>,
    /// Revocation checker
    revocation: Option<Arc<dyn RevocationChecker>>,
    /// Hot reloadable certificates
    /// NOTE(review): stored by `with_hot_reload`; how it is consumed is not visible here.
    hot_reload: Option<Arc<HotReloadableCertificates>>,
}
779
780impl std::fmt::Debug for MtlsConfigBuilder {
781    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
782        f.debug_struct("MtlsConfigBuilder")
783            .field("cert_chain", &format!("<{} certs>", self.cert_chain.len()))
784            .field("private_key", &self.private_key.as_ref().map(|_| "<key>"))
785            .field("client_roots", &"<RootCertStore>")
786            .field("server_roots", &"<RootCertStore>")
787            .field("require_client_auth", &self.require_client_auth)
788            .field("mapper", &"<PrincipalMapper>")
789            .field(
790                "revocation",
791                &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
792            )
793            .field(
794                "hot_reload",
795                &self.hot_reload.as_ref().map(|_| "<HotReloadable>"),
796            )
797            .finish()
798    }
799}
800
801impl Default for MtlsConfigBuilder {
802    fn default() -> Self {
803        Self::new()
804    }
805}
806
807impl MtlsConfigBuilder {
808    /// Create a new mTLS configuration builder
809    pub fn new() -> Self {
810        Self {
811            cert_chain: Vec::new(),
812            private_key: None,
813            client_roots: RootCertStore::empty(),
814            server_roots: RootCertStore::empty(),
815            require_client_auth: true,
816            mapper: Arc::new(DefaultPrincipalMapper),
817            revocation: None,
818            hot_reload: None,
819        }
820    }
821
822    /// Set the server identity (certificate chain and private key)
823    pub fn with_identity(
824        mut self,
825        cert_chain: Vec<CertificateDer<'static>>,
826        private_key: PrivateKeyDer<'static>,
827    ) -> Self {
828        self.cert_chain = cert_chain;
829        self.private_key = Some(private_key);
830        self
831    }
832
833    /// Load server identity from PEM files
834    pub fn with_identity_files<P: AsRef<Path>>(
835        mut self,
836        cert_path: P,
837        key_path: P,
838    ) -> NetResult<Self> {
839        let loader = CertificateLoader::new();
840        let key_loader = crate::tls::PrivateKeyLoader::new();
841
842        self.cert_chain = loader.load_pem_file(cert_path)?;
843        self.private_key = Some(key_loader.load_pem_file(key_path)?);
844
845        Ok(self)
846    }
847
848    /// Add client CA certificate for verification
849    pub fn with_client_ca(mut self, cert: CertificateDer<'static>) -> NetResult<Self> {
850        self.client_roots
851            .add(cert)
852            .map_err(|e| NetError::InvalidCertificate(format!("Failed to add client CA: {e}")))?;
853        Ok(self)
854    }
855
856    /// Add client CA certificates from a store
857    pub fn with_client_ca_store(mut self, store: &CertificateStore) -> Self {
858        let roots = store.get_root_store();
859        self.client_roots.extend(roots.roots.iter().cloned());
860        self
861    }
862
863    /// Add server CA certificate for verification
864    pub fn with_server_ca(mut self, cert: CertificateDer<'static>) -> NetResult<Self> {
865        self.server_roots
866            .add(cert)
867            .map_err(|e| NetError::InvalidCertificate(format!("Failed to add server CA: {e}")))?;
868        Ok(self)
869    }
870
871    /// Add server CA certificates from a store
872    pub fn with_server_ca_store(mut self, store: &CertificateStore) -> Self {
873        let roots = store.get_root_store();
874        self.server_roots.extend(roots.roots.iter().cloned());
875        self
876    }
877
878    /// Add system root certificates for server verification
879    pub fn with_system_roots(mut self) -> Self {
880        self.server_roots
881            .extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
882        self
883    }
884
885    /// Set whether client authentication is required
886    pub fn require_client_auth(mut self, required: bool) -> Self {
887        self.require_client_auth = required;
888        self
889    }
890
891    /// Set the principal mapper
892    pub fn with_mapper(mut self, mapper: Arc<dyn PrincipalMapper>) -> Self {
893        self.mapper = mapper;
894        self
895    }
896
897    /// Set the revocation checker
898    pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
899        self.revocation = Some(checker);
900        self
901    }
902
903    /// Enable hot reload support
904    pub fn with_hot_reload(mut self, hot_reload: Arc<HotReloadableCertificates>) -> Self {
905        self.hot_reload = Some(hot_reload);
906        self
907    }
908
909    /// Build the server configuration
910    pub fn build_server_config(self) -> NetResult<ServerConfig> {
911        let private_key = self
912            .private_key
913            .ok_or_else(|| NetError::InvalidCertificate("Private key is required".to_string()))?;
914
915        if self.cert_chain.is_empty() {
916            return Err(NetError::InvalidCertificate(
917                "Certificate chain is required".to_string(),
918            ));
919        }
920
921        // Create client verifier
922        let client_verifier =
923            Arc::new(MtlsClientVerifier::new(self.client_roots).with_mapper(self.mapper));
924
925        let config = if self.require_client_auth {
926            ServerConfig::builder()
927                .with_client_cert_verifier(client_verifier)
928                .with_single_cert(self.cert_chain, private_key)
929                .map_err(|e| {
930                    NetError::InvalidCertificate(format!("Failed to build server config: {e}"))
931                })?
932        } else {
933            ServerConfig::builder()
934                .with_no_client_auth()
935                .with_single_cert(self.cert_chain, private_key)
936                .map_err(|e| {
937                    NetError::InvalidCertificate(format!("Failed to build server config: {e}"))
938                })?
939        };
940
941        Ok(config)
942    }
943
944    /// Build the client configuration
945    pub fn build_client_config(self) -> NetResult<ClientConfig> {
946        let private_key = self.private_key.ok_or_else(|| {
947            NetError::InvalidCertificate("Private key is required for client mTLS".to_string())
948        })?;
949
950        if self.cert_chain.is_empty() {
951            return Err(NetError::InvalidCertificate(
952                "Certificate chain is required for client mTLS".to_string(),
953            ));
954        }
955
956        // Create server verifier
957        let server_verifier = Arc::new(MtlsServerVerifier::new(self.server_roots));
958
959        let config = ClientConfig::builder()
960            .dangerous()
961            .with_custom_certificate_verifier(server_verifier)
962            .with_client_auth_cert(self.cert_chain, private_key)
963            .map_err(|e| {
964                NetError::InvalidCertificate(format!("Failed to build client config: {e}"))
965            })?;
966
967        Ok(config)
968    }
969
970    /// Build TLS acceptor for server
971    pub fn build_acceptor(self) -> NetResult<TlsAcceptor> {
972        let config = self.build_server_config()?;
973        Ok(TlsAcceptor::from(Arc::new(config)))
974    }
975
976    /// Build TLS connector for client
977    pub fn build_connector(self) -> NetResult<TlsConnector> {
978        let config = self.build_client_config()?;
979        Ok(TlsConnector::from(Arc::new(config)))
980    }
981}
982
983/// mTLS server helper
984pub struct MtlsServer {
985    /// TLS acceptor
986    acceptor: TlsAcceptor,
987    /// Hot reload handle
988    hot_reload: Option<Arc<HotReloadableCertificates>>,
989}
990
991impl std::fmt::Debug for MtlsServer {
992    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
993        f.debug_struct("MtlsServer")
994            .field("has_hot_reload", &self.hot_reload.is_some())
995            .finish()
996    }
997}
998
999impl MtlsServer {
1000    /// Create a new mTLS configuration builder
1001    pub fn builder() -> MtlsConfigBuilder {
1002        MtlsConfigBuilder::new()
1003    }
1004
1005    /// Create from pre-built config
1006    pub fn from_config(config: ServerConfig) -> Self {
1007        Self {
1008            acceptor: TlsAcceptor::from(Arc::new(config)),
1009            hot_reload: None,
1010        }
1011    }
1012
1013    /// Get the TLS acceptor
1014    pub fn acceptor(&self) -> &TlsAcceptor {
1015        &self.acceptor
1016    }
1017
1018    /// Enable hot reload support
1019    pub fn with_hot_reload(mut self, hot_reload: Arc<HotReloadableCertificates>) -> Self {
1020        self.hot_reload = Some(hot_reload);
1021        self
1022    }
1023}
1024
1025/// mTLS client helper
1026pub struct MtlsClient {
1027    /// TLS connector
1028    connector: TlsConnector,
1029}
1030
1031impl std::fmt::Debug for MtlsClient {
1032    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1033        f.debug_struct("MtlsClient").finish()
1034    }
1035}
1036
1037impl MtlsClient {
1038    /// Create a new mTLS configuration builder
1039    pub fn builder() -> MtlsConfigBuilder {
1040        MtlsConfigBuilder::new()
1041    }
1042
1043    /// Create from pre-built config
1044    pub fn from_config(config: ClientConfig) -> Self {
1045        Self {
1046            connector: TlsConnector::from(Arc::new(config)),
1047        }
1048    }
1049
1050    /// Get the TLS connector
1051    pub fn connector(&self) -> &TlsConnector {
1052        &self.connector
1053    }
1054}
1055
1056/// Mutual authentication handshake result
1057#[derive(Debug, Clone)]
1058pub struct HandshakeResult {
1059    /// Peer principal (for server, this is the client; for client, this is the server)
1060    pub peer_principal: Option<Principal>,
1061    /// Negotiated TLS version
1062    pub tls_version: String,
1063    /// Negotiated cipher suite
1064    pub cipher_suite: String,
1065    /// Handshake duration
1066    pub duration: Duration,
1067}
1068
1069impl HandshakeResult {
1070    /// Check if peer authentication was successful
1071    pub fn is_authenticated(&self) -> bool {
1072        self.peer_principal.is_some()
1073    }
1074
1075    /// Get peer principal name
1076    pub fn peer_name(&self) -> Option<&str> {
1077        self.peer_principal.as_ref().map(|p| p.name.as_str())
1078    }
1079}
1080
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::SelfSignedGenerator;

    /// Install ring as the process-wide rustls CryptoProvider.
    ///
    /// The error from `install_default` is ignored on purpose: it only signals that another
    /// test already installed a provider.
    fn install_crypto_provider() {
        rustls::crypto::ring::default_provider()
            .install_default()
            .ok();
    }

    #[test]
    fn test_principal_from_certificate() {
        // Generate a test certificate
        let generator = SelfSignedGenerator::new("test-user").with_organization("Test Org");

        let (cert, _) = generator.generate().expect("Should generate certificate");

        let principal = Principal::from_certificate(&cert).expect("Should create principal");

        assert_eq!(principal.name, "test-user");
        assert_eq!(principal.organization.as_deref(), Some("Test Org"));
        assert!(!principal.fingerprint.is_empty());
    }

    #[test]
    fn test_default_principal_mapper() {
        let generator = SelfSignedGenerator::new("test-user");
        let (cert, _) = generator.generate().expect("Should generate certificate");

        let mapper = DefaultPrincipalMapper;
        let principal = mapper
            .map_certificate(&cert)
            .expect("Should map certificate");
        let name = mapper.get_principal_name(&principal);

        assert_eq!(name, "test-user");
    }

    #[test]
    fn test_organization_principal_mapper() {
        let generator = SelfSignedGenerator::new("test-user").with_organization("Test Org");

        let (cert, _) = generator.generate().expect("Should generate certificate");

        let mapper = OrganizationPrincipalMapper;
        let principal = mapper
            .map_certificate(&cert)
            .expect("Should map certificate");
        let name = mapper.get_principal_name(&principal);

        assert_eq!(name, "Test Org/test-user");
    }

    #[test]
    fn test_crl_revocation_checker() {
        let checker = CrlRevocationChecker::new();

        // Add a revoked serial
        checker.add_revoked("abc123");

        assert!(checker.is_revoked("abc123"));
        assert!(!checker.is_revoked("def456"));
        assert_eq!(checker.revoked_count(), 1);
    }

    #[test]
    fn test_mtls_config_builder() {
        install_crypto_provider();

        // Generate CA certificate
        let ca_generator = SelfSignedGenerator::new("Test CA")
            .as_ca()
            .with_validity_days(365);

        let (ca_cert, _ca_key) = ca_generator.generate().expect("Should generate CA");

        // Generate server certificate
        let server_generator = SelfSignedGenerator::new("localhost").with_san("127.0.0.1");

        let (server_cert, server_key) = server_generator
            .generate()
            .expect("Should generate server cert");

        // Build server config
        let result = MtlsConfigBuilder::new()
            .with_identity(vec![server_cert.clone()], server_key.clone_key())
            .with_client_ca(ca_cert.clone())
            .expect("Should add CA")
            .require_client_auth(true)
            .build_server_config();

        assert!(result.is_ok());
    }

    #[test]
    fn test_mtls_config_builder_without_client_auth() {
        install_crypto_provider();

        let (server_cert, server_key) = SelfSignedGenerator::new("localhost")
            .generate()
            .expect("Should generate server cert");

        // No client CA is configured because client authentication is disabled
        let result = MtlsConfigBuilder::new()
            .with_identity(vec![server_cert], server_key)
            .require_client_auth(false)
            .build_server_config();

        assert!(result.is_ok());
    }

    #[test]
    fn test_build_client_config_requires_private_key() {
        // No identity is configured on this builder, so building a client config must fail
        let result = MtlsConfigBuilder::new().build_client_config();

        assert!(result.is_err());
    }

    #[test]
    fn test_mtls_client_verifier() {
        // Generate CA and client certificates
        let ca_generator = SelfSignedGenerator::new("Test CA").as_ca();

        let (ca_cert, _) = ca_generator.generate().expect("Should generate CA");

        let client_generator =
            SelfSignedGenerator::new("test-client").with_organization("Test Org");

        let (client_cert, _) = client_generator
            .generate()
            .expect("Should generate client cert");

        // Construct the verifier from a populated root store. The client certificate is
        // self-signed rather than issued by the CA, so no chain verification is attempted;
        // only its parsed metadata is checked below.
        let mut roots = RootCertStore::empty();
        roots.add(ca_cert).expect("Should add CA");

        let _verifier = MtlsClientVerifier::new(roots);

        let loader = CertificateLoader::new();
        let info = loader
            .get_certificate_info(&client_cert)
            .expect("Should get info");

        assert_eq!(info.common_name.as_deref(), Some("test-client"));
    }

    #[test]
    fn test_ocsp_revocation_checker_cache() {
        // Cache should initially be empty
        let checker = OcspRevocationChecker::new().with_cache_ttl(Duration::from_secs(3600));

        let generator = SelfSignedGenerator::new("test");
        let (cert, _) = generator.generate().expect("Should generate cert");

        // First check should return Unknown (no OCSP response cached)
        let status = checker
            .check_revocation(&cert)
            .expect("Should check revocation");
        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[test]
    fn test_combined_revocation_checker() {
        let crl = Arc::new(CrlRevocationChecker::new());
        let ocsp = Arc::new(OcspRevocationChecker::new());

        let combined = CombinedRevocationChecker::new(crl.clone(), ocsp);

        let generator = SelfSignedGenerator::new("test");
        let (cert, _) = generator.generate().expect("Should generate cert");

        // Should return Unknown since neither has data
        let status = combined
            .check_revocation(&cert)
            .expect("Should check revocation");
        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[test]
    fn test_handshake_result() {
        let principal = Principal {
            name: "test-user".to_string(),
            organization: Some("Test Org".to_string()),
            organizational_unit: None,
            email: None,
            serial: "123abc".to_string(),
            fingerprint: "abc123".to_string(),
            attributes: HashMap::new(),
        };

        let result = HandshakeResult {
            peer_principal: Some(principal),
            tls_version: "TLS 1.3".to_string(),
            cipher_suite: "TLS_AES_256_GCM_SHA384".to_string(),
            duration: Duration::from_millis(50),
        };

        assert!(result.is_authenticated());
        assert_eq!(result.peer_name(), Some("test-user"));
    }
}